baskarmother commited on
Commit
deaa21d
·
verified ·
1 Parent(s): acff799

Add inference script for PPE detection

Browse files
Files changed (1) hide show
  1. inference.py +115 -0
inference.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PPE Compliance Detection - Inference Script
3
+ Usage with trained YOLOv8 model
4
+ """
5
+ from ultralytics import YOLO
6
+ from huggingface_hub import hf_hub_download
7
+ from PIL import Image
8
+ import cv2
9
+ import numpy as np
10
+
11
+ MODEL_ID = "baskarmother/yolov8-ppe-construction"
12
+
13
+ # PPE compliance pairs: (required_ppe, violation_class)
14
+ PPE_PAIRS = {
15
+ "hardhat": ("hardhat", "no-hardhat"),
16
+ "mask": ("mask", "no-mask"),
17
+ "safety_vest": ("safety vest", "no-safety vest"),
18
+ }
19
+
20
+ CLASS_NAMES = [
21
+ 'barricade', 'dumpster', 'excavators', 'gloves', 'hardhat', 'mask',
22
+ 'no-hardhat', 'no-mask', 'no-safety vest', 'person', 'safety net',
23
+ 'safety shoes', 'safety vest', 'dump truck', 'mini-van', 'truck', 'wheel loader'
24
+ ]
25
+
26
+ def load_model():
27
+ """Load model from HuggingFace Hub."""
28
+ weights_path = hf_hub_download(MODEL_ID, "best.pt")
29
+ model = YOLO(weights_path)
30
+ return model
31
+
32
+ def detect(image_path, conf_threshold=0.25):
33
+ """Run detection on an image."""
34
+ model = load_model()
35
+ results = model(image_path, conf=conf_threshold)
36
+ return results[0]
37
+
38
+ def check_compliance(result):
39
+ """Check PPE compliance from detection results."""
40
+ boxes = result.boxes
41
+ detected_classes = set()
42
+
43
+ for cls in boxes.cls:
44
+ detected_classes.add(CLASS_NAMES[int(cls)])
45
+
46
+ compliance_report = {}
47
+ for ppe_name, (required, violation) in PPE_PAIRS.items():
48
+ has_required = required in detected_classes
49
+ has_violation = violation in detected_classes
50
+
51
+ if has_violation:
52
+ status = "VIOLATION"
53
+ elif has_required:
54
+ status = "COMPLIANT"
55
+ else:
56
+ status = "NOT DETECTED"
57
+
58
+ compliance_report[ppe_name] = {
59
+ "status": status,
60
+ "required_detected": has_required,
61
+ "violation_detected": has_violation
62
+ }
63
+
64
+ return compliance_report
65
+
66
+ def draw_detections(image_path, result, save_path="output.jpg"):
67
+ """Draw bounding boxes and labels on image."""
68
+ img = cv2.imread(image_path)
69
+ boxes = result.boxes
70
+
71
+ for box in boxes:
72
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
73
+ cls_id = int(box.cls[0])
74
+ conf = float(box.conf[0])
75
+ label = CLASS_NAMES[cls_id]
76
+
77
+ # Color: green for compliant PPE, red for violations, blue for others
78
+ if label in ["hardhat", "mask", "safety vest", "gloves", "safety shoes"]:
79
+ color = (0, 255, 0)
80
+ elif label in ["no-hardhat", "no-mask", "no-safety vest"]:
81
+ color = (0, 0, 255)
82
+ else:
83
+ color = (255, 0, 0)
84
+
85
+ cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
86
+ text = f"{label} {conf:.2f}"
87
+ cv2.putText(img, text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
88
+
89
+ cv2.imwrite(save_path, img)
90
+ return save_path
91
+
92
+ if __name__ == "__main__":
93
+ import argparse
94
+ parser = argparse.ArgumentParser()
95
+ parser.add_argument("--image", required=True, help="Path to input image")
96
+ parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
97
+ parser.add_argument("--output", default="output.jpg", help="Output image path")
98
+ args = parser.parse_args()
99
+
100
+ print("Loading model...")
101
+ result = detect(args.image, args.conf)
102
+
103
+ print(f"Detected {len(result.boxes)} objects")
104
+ for box in result.boxes:
105
+ cls_id = int(box.cls[0])
106
+ conf = float(box.conf[0])
107
+ print(f" - {CLASS_NAMES[cls_id]}: {conf:.3f}")
108
+
109
+ print("\n--- PPE Compliance Report ---")
110
+ report = check_compliance(result)
111
+ for ppe, info in report.items():
112
+ print(f" {ppe}: {info['status']}")
113
+
114
+ draw_detections(args.image, result, args.output)
115
+ print(f"\nAnnotated image saved to {args.output}")