# Source: ppe-training-scripts / inference.py
# Uploaded by baskarmother; commit deaa21d (verified): "Add inference script for PPE detection"
# File size: 3.76 kB
"""
PPE Compliance Detection - Inference Script
Usage with trained YOLOv8 model
"""
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from PIL import Image
import cv2
import numpy as np
# Hugging Face Hub repository from which the weights file is downloaded.
MODEL_ID = "baskarmother/yolov8-ppe-construction"

# PPE compliance pairs, keyed by the name used in the compliance report:
#   report key -> (class name for the PPE item, class name for its violation)
PPE_PAIRS = {
    "hardhat": ("hardhat", "no-hardhat"),
    "mask": ("mask", "no-mask"),
    "safety_vest": ("safety vest", "no-safety vest"),
}

# Class labels, looked up by integer class id (position in the list).
# NOTE(review): presumably this order mirrors the model's training label
# order - confirm against the model's own class-name mapping.
CLASS_NAMES = [
    "barricade",
    "dumpster",
    "excavators",
    "gloves",
    "hardhat",
    "mask",
    "no-hardhat",
    "no-mask",
    "no-safety vest",
    "person",
    "safety net",
    "safety shoes",
    "safety vest",
    "dump truck",
    "mini-van",
    "truck",
    "wheel loader",
]
def load_model():
    """Build a YOLO model from the weights published on the Hugging Face Hub.

    Requests the ``best.pt`` file of the ``MODEL_ID`` repository through
    ``hf_hub_download`` and passes the value it returns to ``YOLO``.

    Returns:
        The ``YOLO`` instance constructed from the downloaded weights.
    """
    return YOLO(hf_hub_download(MODEL_ID, "best.pt"))
def detect(image_path, conf_threshold=0.25, model=None):
    """Run detection on one image and return its result.

    Args:
        image_path: Image source, handed to the model unchanged.
        conf_threshold: Value forwarded to the model as ``conf``.
        model: Optional, already-loaded model; anything callable as
            ``model(source, conf=...)``. When omitted, ``load_model()`` is
            called, which repeats the load on every call. Pass a model
            explicitly to reuse it across many images.

    Returns:
        The first element of the sequence returned by the model call.
    """
    if model is None:
        model = load_model()
    predictions = model(image_path, conf=conf_threshold)
    return predictions[0]
def check_compliance(result, class_names=None, ppe_pairs=None):
    """Summarise PPE compliance for one detection result.

    The check is image-wide: it looks only at which class names occur among
    the detections, not at which person a detection belongs to.

    Args:
        result: Detection result exposing ``boxes``; ``boxes.cls`` must be an
            iterable of class ids convertible with ``int()``. ``boxes`` may be
            ``None``, which is treated as "nothing detected".
        class_names: Optional sequence mapping class id to label. Defaults to
            the module-level ``CLASS_NAMES``.
        ppe_pairs: Optional mapping ``report_key -> (required, violation)``
            label pair. Defaults to the module-level ``PPE_PAIRS``.

    Returns:
        A dict keyed like ``ppe_pairs``. Each value is a dict with
        ``"status"`` (``"VIOLATION"``, ``"COMPLIANT"`` or ``"NOT DETECTED"``),
        ``"required_detected"`` and ``"violation_detected"``.

    Raises:
        IndexError: If a detected class id has no entry in ``class_names``.
    """
    if class_names is None:
        class_names = CLASS_NAMES
    if ppe_pairs is None:
        ppe_pairs = PPE_PAIRS

    boxes = result.boxes
    # A result without a boxes object carries no detections to inspect.
    class_ids = () if boxes is None else boxes.cls
    detected = {class_names[int(class_id)] for class_id in class_ids}

    report = {}
    for ppe_name, (required, violation) in ppe_pairs.items():
        has_required = required in detected
        has_violation = violation in detected
        # A violation anywhere in the image outranks the presence of the item.
        if has_violation:
            status = "VIOLATION"
        elif has_required:
            status = "COMPLIANT"
        else:
            status = "NOT DETECTED"
        report[ppe_name] = {
            "status": status,
            "required_detected": has_required,
            "violation_detected": has_violation,
        }
    return report
def draw_detections(image_path, result, save_path="output.jpg"):
    """Draw bounding boxes and labels on an image and save the annotated copy.

    Args:
        image_path: Path of the image to annotate; read with ``cv2.imread``.
        result: Detection result exposing ``boxes``. Each box must provide
            ``xyxy``, ``cls`` and ``conf``. ``boxes`` may be ``None``, in
            which case the image is saved without annotations.
        save_path: Destination path passed to ``cv2.imwrite``.

    Returns:
        ``save_path``, once the annotated image has been written.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``image_path``.
        OSError: If ``cv2.imwrite`` reports that the image was not written.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Built once instead of on every loop iteration.
    compliant_labels = {"hardhat", "mask", "safety vest", "gloves", "safety shoes"}
    violation_labels = {"no-hardhat", "no-mask", "no-safety vest"}

    boxes = result.boxes
    for box in (boxes if boxes is not None else ()):
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cls_id = int(box.cls[0])
        conf = float(box.conf[0])
        label = CLASS_NAMES[cls_id]

        # Colors are in OpenCV's BGR channel order:
        # green for worn PPE, red for violations, blue for everything else.
        if label in compliant_labels:
            color = (0, 255, 0)
        elif label in violation_labels:
            color = (0, 0, 255)
        else:
            color = (255, 0, 0)

        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        text = f"{label} {conf:.2f}"
        # Keep the label inside the image when the box touches the top edge.
        text_y = max(y1 - 10, 15)
        cv2.putText(img, text, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # cv2.imwrite returns False when nothing was written; do not report
    # success in that case.
    if not cv2.imwrite(save_path, img):
        raise OSError(f"Could not write annotated image to: {save_path}")
    return save_path
def main():
    """Command-line entry point.

    Parses ``--image``, ``--conf`` and ``--output``, runs detection on the
    image, prints every detection and the PPE compliance report, then saves
    the annotated image.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--image", required=True, help="Path to input image")
    cli.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    cli.add_argument("--output", default="output.jpg", help="Output image path")
    options = cli.parse_args()

    print("Loading model...")
    detection = detect(options.image, options.conf)

    # One line per detected object: class label and confidence.
    print(f"Detected {len(detection.boxes)} objects")
    for found in detection.boxes:
        class_index = int(found.cls[0])
        score = float(found.conf[0])
        print(f" - {CLASS_NAMES[class_index]}: {score:.3f}")

    print("\n--- PPE Compliance Report ---")
    for item, details in check_compliance(detection).items():
        print(f" {item}: {details['status']}")

    draw_detections(options.image, detection, options.output)
    print(f"\nAnnotated image saved to {options.output}")


if __name__ == "__main__":
    main()