Nekshay commited on
Commit
993159a
·
verified ·
1 Parent(s): 22fd6e9

Create ROC_curve_TFlite_Model.py

Browse files
Files changed (1) hide show
  1. ROC_curve_TFlite_Model.py +100 -0
ROC_curve_TFlite_Model.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import matplotlib.pyplot as plt
4
+ import xml.etree.ElementTree as ET
5
+ import cv2
6
+ import glob
7
+ import os
8
+ from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
9
+
10
# Define paths
test_images_dir = "test_images/" # Path to test images (.jpg files)
test_annotations_dir = "test_annotations/" # Path to Pascal VOC XML files (one per image, same basename)
tflite_model_path = "efficientdet_lite0.tflite"  # TFLite detection model used for scoring

# Load TFLite model once up front and allocate its input/output tensors;
# the same interpreter instance is reused for every test image below.
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
18
+
19
# Function to run inference
def run_inference(interpreter, image):
    """Run one forward pass of the TFLite model on a single BGR image.

    Args:
        interpreter: tf.lite.Interpreter with tensors already allocated.
        image: HxWx3 uint8 image as returned by cv2.imread (BGR order).

    Returns:
        The model's first output tensor (confidence scores per the original
        script's usage — see NOTE below).
    """
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # input_shape is [batch, height, width, channels]; cv2.resize expects
    # dsize=(width, height), so the indices must be (shape[2], shape[1]).
    # The original passed (shape[1], shape[2]) — only correct for square inputs.
    input_shape = input_details[0]['shape']
    height, width = int(input_shape[1]), int(input_shape[2])

    # cv2 decodes images as BGR; TF detection models are trained on RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (width, height))

    if input_details[0]['dtype'] == np.uint8:
        # Quantized model (stock EfficientDet-Lite exports take uint8 input);
        # feeding float32 here would make set_tensor raise a dtype mismatch.
        input_data = np.expand_dims(image.astype(np.uint8), axis=0)
    else:
        # Float model: normalize to [0, 1] and add the batch dimension.
        input_data = np.expand_dims(image.astype(np.float32) / 255.0, axis=0)

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # NOTE(review): assumes output tensor 0 holds confidence scores.
    # EfficientDet-Lite exports often order outputs [boxes, classes, scores,
    # count] — verify against output_details for this specific model file.
    output_data = interpreter.get_tensor(output_details[0]['index'])
    return output_data
37
+
38
# Function to parse Pascal VOC XML annotation
def parse_voc_annotation(xml_file):
    """Return 1 if the Pascal VOC file contains at least one <object>, else 0."""
    root = ET.parse(xml_file).getroot()
    # Presence of any <object> element means "object present" for this
    # image-level binary label.
    return int(root.find("object") is not None)
45
+
46
# Load test images and collect (score, label) pairs for every annotated image.
image_files = glob.glob(os.path.join(test_images_dir, "*.jpg"))  # Adjust if using .png
y_scores = []
y_true = []

for image_file in image_files:
    # Load image; cv2.imread returns None (rather than raising) for a
    # missing/corrupt file — skip it instead of crashing downstream.
    image = cv2.imread(image_file)
    if image is None:
        continue

    # Matching annotation: same basename with an .xml extension.
    xml_file = os.path.join(
        test_annotations_dir,
        os.path.splitext(os.path.basename(image_file))[0] + ".xml",
    )
    if not os.path.exists(xml_file):
        continue  # Skip if annotation is missing

    # Ground-truth image-level label (1 = object present, 0 = no object).
    true_label = parse_voc_annotation(xml_file)

    # Run inference and keep the single highest confidence score per image.
    scores = run_inference(interpreter, image)
    max_score = float(np.max(scores))

    y_scores.append(max_score)
    y_true.append(true_label)

# Convert to numpy arrays for sklearn.
y_scores = np.array(y_scores)
y_true = np.array(y_true)

# ROC and PR curves are undefined unless both classes are represented;
# fail early with a clear message instead of an opaque sklearn error.
if len(np.unique(y_true)) < 2:
    raise SystemExit(
        "Need at least one positive and one negative test image to compute ROC/PR curves."
    )

# Compute ROC curve and AUC
fpr, tpr, _ = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)

# Compute Precision-Recall curve and AP score
precision, recall, _ = precision_recall_curve(y_true, y_scores)
average_precision = average_precision_score(y_true, y_scores)

# Plot ROC Curve
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC Curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')  # Chance-level diagonal
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.show()

# Plot Precision-Recall Curve
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, color='green', lw=2, label=f'PR Curve (AP = {average_precision:.2f})')
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.legend()
plt.show()