baskarmother commited on
Commit
87e2c95
·
verified ·
1 Parent(s): deaa21d

Add Gradio demo app for PPE detection

Browse files
Files changed (1) hide show
  1. gradio_app.py +110 -0
gradio_app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for PPE Compliance Detection
3
+ Run: gradio gradio_app.py
4
+ """
5
+ import gradio as gr
6
+ from ultralytics import YOLO
7
+ from huggingface_hub import hf_hub_download
8
+ from PIL import Image
9
+ import cv2
10
+ import numpy as np
11
+
12
+ MODEL_ID = "baskarmother/yolov8-ppe-construction"
13
+ CLASS_NAMES = [
14
+ 'barricade', 'dumpster', 'excavators', 'gloves', 'hardhat', 'mask',
15
+ 'no-hardhat', 'no-mask', 'no-safety vest', 'person', 'safety net',
16
+ 'safety shoes', 'safety vest', 'dump truck', 'mini-van', 'truck', 'wheel loader'
17
+ ]
18
+
19
+ def get_color(label):
20
+ if label in ["hardhat", "mask", "safety vest", "gloves", "safety shoes"]:
21
+ return (0, 255, 0) # Green = compliant PPE
22
+ elif label in ["no-hardhat", "no-mask", "no-safety vest"]:
23
+ return (0, 0, 255) # Red = violation
24
+ else:
25
+ return (255, 0, 0) # Blue = other
26
+
27
+ # Load model once
28
+ print("Loading model from Hub...")
29
+ weights_path = hf_hub_download(MODEL_ID, "best.pt")
30
+ model = YOLO(weights_path)
31
+ print("Model loaded!")
32
+
33
+ def detect_ppe(image, conf_threshold=0.25):
34
+ results = model(image, conf=conf_threshold)
35
+ result = results[0]
36
+
37
+ # Convert PIL to numpy array
38
+ img = np.array(image)
39
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
40
+
41
+ detections = []
42
+ compliance = {"hardhat": "NOT DETECTED", "mask": "NOT DETECTED", "safety_vest": "NOT DETECTED"}
43
+ detected_classes = set()
44
+
45
+ for box in result.boxes:
46
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
47
+ cls_id = int(box.cls[0])
48
+ conf = float(box.conf[0])
49
+ label = CLASS_NAMES[cls_id]
50
+ detected_classes.add(label)
51
+
52
+ color = get_color(label)
53
+ cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
54
+ text = f"{label} {conf:.2f}"
55
+ cv2.putText(img, text, (x1, max(y1 - 10, 0)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
56
+ detections.append(f"{label}: {conf:.2f}")
57
+
58
+ # Compliance check
59
+ if "hardhat" in detected_classes:
60
+ compliance["hardhat"] = "✅ COMPLIANT"
61
+ if "no-hardhat" in detected_classes:
62
+ compliance["hardhat"] = "❌ VIOLATION"
63
+ if "mask" in detected_classes:
64
+ compliance["mask"] = "✅ COMPLIANT"
65
+ if "no-mask" in detected_classes:
66
+ compliance["mask"] = "❌ VIOLATION"
67
+ if "safety vest" in detected_classes:
68
+ compliance["safety_vest"] = "✅ COMPLIANT"
69
+ if "no-safety vest" in detected_classes:
70
+ compliance["safety_vest"] = "❌ VIOLATION"
71
+
72
+ # Convert back to RGB
73
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
74
+ output_image = Image.fromarray(img_rgb)
75
+
76
+ report = f"""### PPE Compliance Report
77
+ - **Hard Hat**: {compliance['hardhat']}
78
+ - **Mask**: {compliance['mask']}
79
+ - **Safety Vest**: {compliance['safety_vest']}
80
+
81
+ ### Detections ({len(detections)} objects)
82
+ {chr(10).join(['- ' + d for d in detections]) if detections else 'No objects detected'}
83
+ """
84
+ return output_image, report
85
+
86
+ with gr.Blocks(title="PPE Compliance Detection") as demo:
87
+ gr.Markdown("# 🦺 PPE Compliance Detection for Construction Sites")
88
+ gr.Markdown("Upload a construction site image to detect Personal Protective Equipment (PPE) compliance.")
89
+
90
+ with gr.Row():
91
+ with gr.Column():
92
+ input_image = gr.Image(type="pil", label="Upload Image")
93
+ conf_slider = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="Confidence Threshold")
94
+ detect_btn = gr.Button("Detect PPE", variant="primary")
95
+
96
+ with gr.Column():
97
+ output_image = gr.Image(label="Detected Objects")
98
+ output_text = gr.Markdown(label="Compliance Report")
99
+
100
+ detect_btn.click(fn=detect_ppe, inputs=[input_image, conf_slider], outputs=[output_image, output_text])
101
+
102
+ gr.Markdown("""
103
+ ## Classes Detected
104
+ - ✅ **Compliant PPE**: hardhat, mask, safety vest, gloves, safety shoes
105
+ - ❌ **Violations**: no-hardhat, no-mask, no-safety vest
106
+ - 🔵 **Other**: person, equipment, vehicles, barricades
107
+ """)
108
+
109
+ if __name__ == "__main__":
110
+ demo.launch()