| """ |
| PPE Compliance Detection - Inference Script |
| Usage with trained YOLOv8 model |
| """ |
| from ultralytics import YOLO |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
| import cv2 |
| import numpy as np |
|
|
# Hugging Face Hub repository that hosts the trained weights.
MODEL_ID = "baskarmother/yolov8-ppe-construction"


# PPE category -> (label meaning the item is present, label meaning it is missing).
PPE_PAIRS = dict(
    hardhat=("hardhat", "no-hardhat"),
    mask=("mask", "no-mask"),
    safety_vest=("safety vest", "no-safety vest"),
)


# Class labels indexed by the model's integer class id.
# NOTE(review): the order must match the class order the model was trained with — confirm.
CLASS_NAMES = [
    "barricade",       # 0
    "dumpster",        # 1
    "excavators",      # 2
    "gloves",          # 3
    "hardhat",         # 4
    "mask",            # 5
    "no-hardhat",      # 6
    "no-mask",         # 7
    "no-safety vest",  # 8
    "person",          # 9
    "safety net",      # 10
    "safety shoes",    # 11
    "safety vest",     # 12
    "dump truck",      # 13
    "mini-van",        # 14
    "truck",           # 15
    "wheel loader",    # 16
]
|
|
def load_model(repo_id=None, filename="best.pt", revision=None):
    """Fetch YOLO weights from the Hugging Face Hub and build a model from them.

    Args:
        repo_id: Hub repository holding the weights. ``None`` (the default)
            means the module-level ``MODEL_ID``, looked up at call time.
        filename: Weights file inside the repository. Defaults to "best.pt".
        revision: Optional git revision (branch, tag or commit hash) to pin for
            reproducible results. ``None`` leaves the choice to
            ``hf_hub_download``'s own default.

    Returns:
        A ``YOLO`` model constructed from the downloaded weights file.
    """
    if repo_id is None:
        # Resolve lazily so later changes to MODEL_ID are still honoured.
        repo_id = MODEL_ID
    weights_path = hf_hub_download(repo_id, filename, revision=revision)
    return YOLO(weights_path)
|
|
def detect(image_path, conf_threshold=0.25, *, model=None):
    """Run detection on a single image and return its result.

    Args:
        image_path: Image source handed straight to the model (a file path
            when used from the CLI).
        conf_threshold: Minimum confidence for a detection to be kept;
            forwarded as the model's ``conf`` argument.
        model: Optional already-loaded model. When omitted, ``load_model()``
            is called on every invocation. Pass a model in to avoid reloading
            the weights when processing many images.

    Returns:
        The first entry of the model's result list, i.e. the result for
        ``image_path``.
    """
    if model is None:
        model = load_model()
    results = model(image_path, conf=conf_threshold)
    return results[0]
|
|
def check_compliance(result, *, class_names=None, ppe_pairs=None):
    """Summarise PPE compliance for one detection result.

    For each PPE category the status is:

    * ``"VIOLATION"``: the category's violation label (e.g. "no-hardhat")
      was detected anywhere in the image. This takes precedence even if the
      PPE item itself was also detected, so a single violation is enough.
    * ``"COMPLIANT"``: only the PPE label was detected.
    * ``"NOT DETECTED"``: neither label was detected.

    Args:
        result: Detection result whose ``boxes.cls`` iterates over numeric
            class ids (as returned by ``detect``).
        class_names: Sequence mapping class id -> label. Defaults to
            ``CLASS_NAMES``.
        ppe_pairs: Mapping of category -> (required_label, violation_label).
            Defaults to ``PPE_PAIRS``.

    Returns:
        dict mapping each category to a dict with keys ``"status"``,
        ``"required_detected"`` and ``"violation_detected"``.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    if ppe_pairs is None:
        ppe_pairs = PPE_PAIRS

    # Image-level view: which labels appear at least once, regardless of count.
    detected_classes = {class_names[int(cls_id)] for cls_id in result.boxes.cls}

    compliance_report = {}
    for ppe_name, (required, violation) in ppe_pairs.items():
        has_required = required in detected_classes
        has_violation = violation in detected_classes

        if has_violation:
            status = "VIOLATION"
        elif has_required:
            status = "COMPLIANT"
        else:
            status = "NOT DETECTED"

        compliance_report[ppe_name] = {
            "status": status,
            "required_detected": has_required,
            "violation_detected": has_violation,
        }

    return compliance_report
|
|
def draw_detections(image_path, result, save_path="output.jpg", *, class_names=None):
    """Draw a labelled bounding box for every detection and save the image.

    Box colours (OpenCV colour tuples are BGR):

    * green ``(0, 255, 0)``: PPE present ("hardhat", "mask", "safety vest",
      "gloves", "safety shoes").
    * red ``(0, 0, 255)``: PPE missing ("no-hardhat", "no-mask",
      "no-safety vest").
    * blue ``(255, 0, 0)``: every other class.

    Args:
        image_path: Path of the image the detections were computed on.
        result: Detection result whose boxes expose ``xyxy``, ``cls`` and
            ``conf``.
        save_path: Destination file for the annotated image.
        class_names: Sequence mapping class id -> label. Defaults to
            ``CLASS_NAMES``.

    Returns:
        ``save_path``.

    Raises:
        OSError: if ``image_path`` cannot be read or ``save_path`` cannot be
            written by OpenCV.
    """
    if class_names is None:
        class_names = CLASS_NAMES

    # Built once per call instead of once per box.
    ppe_present = {"hardhat", "mask", "safety vest", "gloves", "safety shoes"}
    ppe_missing = {"no-hardhat", "no-mask", "no-safety vest"}

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f"Could not read image: {image_path}")

    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        label = class_names[int(box.cls[0])]
        conf = float(box.conf[0])

        if label in ppe_present:
            color = (0, 255, 0)
        elif label in ppe_missing:
            color = (0, 0, 255)
        else:
            color = (255, 0, 0)

        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        # Place the label just above the box, but keep it on-canvas when
        # the box touches the top edge of the image.
        text_y = max(y1 - 10, 12)
        cv2.putText(img, f"{label} {conf:.2f}", (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # cv2.imwrite returns False on failure instead of raising.
    if not cv2.imwrite(save_path, img):
        raise OSError(f"Could not write annotated image: {save_path}")
    return save_path
|
|
if __name__ == "__main__":
    import argparse

    # Command-line interface.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image", required=True, help="Path to input image")
    cli.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    cli.add_argument("--output", default="output.jpg", help="Output image path")
    opts = cli.parse_args()

    # Run detection; detect() loads the model itself via load_model().
    print("Loading model...")
    detection = detect(opts.image, opts.conf)

    # List every detected box with its confidence.
    print(f"Detected {len(detection.boxes)} objects")
    for det_box in detection.boxes:
        class_idx = int(det_box.cls[0])
        score = float(det_box.conf[0])
        label = CLASS_NAMES[class_idx]
        print(f" - {label}: {score:.3f}")

    # Per-category compliance summary.
    print("\n--- PPE Compliance Report ---")
    for category, details in check_compliance(detection).items():
        print(f" {category}: {details['status']}")

    # Write the annotated copy of the input image.
    draw_detections(opts.image, detection, opts.output)
    print(f"\nAnnotated image saved to {opts.output}")
|
|