File size: 3,760 Bytes
deaa21d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
PPE Compliance Detection - Inference Script
Run a trained YOLOv8 model on an image, print a PPE compliance report, and save an annotated copy.
"""
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from PIL import Image
import cv2
import numpy as np

# HuggingFace Hub repository that holds the trained weights file ("best.pt").
MODEL_ID = "baskarmother/yolov8-ppe-construction"

# PPE compliance categories: report key -> (PPE class name, violation class name).
# Both names of every pair are expected to appear in CLASS_NAMES.
PPE_PAIRS = dict(
    hardhat=("hardhat", "no-hardhat"),
    mask=("mask", "no-mask"),
    safety_vest=("safety vest", "no-safety vest"),
)

# Class names indexed by the integer class id the model emits. The order is
# not alphabetical; presumably it mirrors the training dataset's class list —
# NOTE(review): keep in sync with the weights (confirm against the model's names).
CLASS_NAMES = [
    "barricade",       # 0
    "dumpster",        # 1
    "excavators",      # 2
    "gloves",          # 3
    "hardhat",         # 4
    "mask",            # 5
    "no-hardhat",      # 6
    "no-mask",         # 7
    "no-safety vest",  # 8
    "person",          # 9
    "safety net",      # 10
    "safety shoes",    # 11
    "safety vest",     # 12
    "dump truck",      # 13
    "mini-van",        # 14
    "truck",           # 15
    "wheel loader",    # 16
]

def load_model():
    """Fetch ``best.pt`` for ``MODEL_ID`` from the HuggingFace Hub and wrap it in a YOLO model.

    Returns:
        A freshly constructed ``YOLO`` instance for the downloaded weights path.
    """
    return YOLO(hf_hub_download(MODEL_ID, "best.pt"))

def detect(image_path, conf_threshold=0.25, *, model=None):
    """Run detection on a single image and return its result.

    Args:
        image_path: Path to the input image, passed straight to the model.
        conf_threshold: Minimum confidence, forwarded to the model as ``conf``.
        model: Optional already-loaded model. When omitted, ``load_model()`` is
            called, which builds a new model on every call; pass a model
            explicitly to reuse it across many images.

    Returns:
        The first element of the model's results list, i.e. the result for
        ``image_path``.
    """
    if model is None:
        model = load_model()
    results = model(image_path, conf=conf_threshold)
    return results[0]

def check_compliance(result):
    """Summarise PPE compliance for every category in ``PPE_PAIRS``.

    The check is image-level, not per person. A category is reported as:

    - "VIOLATION" if its violation class was detected anywhere in the image,
      even when the matching PPE class was also detected;
    - "COMPLIANT" if only the PPE class was detected;
    - "NOT DETECTED" if neither was detected.

    Args:
        result: Detection result whose ``boxes.cls`` yields one class index per box.

    Returns:
        Dict keyed by the ``PPE_PAIRS`` keys. Each value is a dict with
        ``status``, ``required_detected`` and ``violation_detected``.
    """
    seen = {CLASS_NAMES[int(idx)] for idx in result.boxes.cls}

    report = {}
    for category, (ppe_label, missing_label) in PPE_PAIRS.items():
        ppe_seen = ppe_label in seen
        missing_seen = missing_label in seen
        # A violation outranks a positive PPE sighting.
        verdict = (
            "VIOLATION" if missing_seen
            else "COMPLIANT" if ppe_seen
            else "NOT DETECTED"
        )
        report[category] = {
            "status": verdict,
            "required_detected": ppe_seen,
            "violation_detected": missing_seen,
        }
    return report

# Colours are BGR tuples, OpenCV's channel order.
_PPE_COLOR = (0, 255, 0)        # green: PPE item present
_VIOLATION_COLOR = (0, 0, 255)  # red: missing-PPE class
_OTHER_COLOR = (255, 0, 0)      # blue: any other class
_PPE_LABELS = frozenset({"hardhat", "mask", "safety vest", "gloves", "safety shoes"})
_VIOLATION_LABELS = frozenset({"no-hardhat", "no-mask", "no-safety vest"})


def _label_color(label):
    """Return the BGR box colour for a class label (PPE, violation, or other)."""
    if label in _PPE_LABELS:
        return _PPE_COLOR
    if label in _VIOLATION_LABELS:
        return _VIOLATION_COLOR
    return _OTHER_COLOR


def draw_detections(image_path, result, save_path="output.jpg"):
    """Draw each detected box and a "label confidence" caption onto the image.

    Box colours follow ``_label_color``: green for PPE items, red for violation
    classes, blue for everything else.

    Args:
        image_path: Path of the image the detections were computed on.
        result: Detection result whose ``boxes`` expose ``xyxy``, ``cls`` and
            ``conf`` per box (as returned by ``detect``).
        save_path: Where to write the annotated image.

    Returns:
        ``save_path``.

    Raises:
        FileNotFoundError: If ``image_path`` cannot be read as an image.
        OSError: If the annotated image cannot be written to ``save_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the first cv2.rectangle call fails obscurely.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = CLASS_NAMES[int(box.cls[0])]
        conf = float(box.conf[0])
        color = _label_color(label)

        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        # Captions sit above the box. Clamp y so a box touching the top edge
        # still gets a visible caption; 15 px is roughly one text line at
        # scale 0.5.
        text_y = max(y1 - 10, 15)
        cv2.putText(
            img, f"{label} {conf:.2f}", (x1, text_y),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2,
        )

    # cv2.imwrite reports failure (e.g. a missing output directory) through
    # its return value rather than raising.
    if not cv2.imwrite(save_path, img):
        raise OSError(f"Could not write annotated image to: {save_path}")
    return save_path

if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--image", required=True, help="Path to input image")
    cli.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    cli.add_argument("--output", default="output.jpg", help="Output image path")
    opts = cli.parse_args()

    # detect() loads the model itself when none is supplied, so this message
    # precedes the whole load-and-infer step.
    print("Loading model...")
    detection = detect(opts.image, opts.conf)

    # List every detection with its confidence.
    found = detection.boxes
    print(f"Detected {len(found)} objects")
    for item in found:
        print(f"  - {CLASS_NAMES[int(item.cls[0])]}: {float(item.conf[0]):.3f}")

    # Summarise PPE compliance per category.
    print("\n--- PPE Compliance Report ---")
    compliance = check_compliance(detection)
    for category, details in compliance.items():
        print(f"  {category}: {details['status']}")

    # Save the annotated image.
    annotated = opts.output
    draw_detections(opts.image, detection, annotated)
    print(f"\nAnnotated image saved to {annotated}")