# ppe-training-scripts / train_ppe_improved_v2.py
# Uploaded by baskarmother (commit d877ac8, verified)
#!/usr/bin/env python3
"""
Improved PPE Compliance Detection Training Script v2
Fixed: Added config='full' for keremberke dataset
Combines multiple datasets for better coverage:
1. 51ddhesh/PPE_Detection (~10K images, 6 PPE classes, YOLO format)
2. keremberke/construction-safety-object-detection (398 images, 17 classes incl. violations)
Trains YOLOv8s on combined data.
"""
import os
import sys
import zipfile
import shutil
from pathlib import Path
from huggingface_hub import hf_hub_download, HfApi
from datasets import load_dataset
from PIL import Image
import yaml
# ========== CONFIG ==========
# Hub account and target model repo name for the trained weights.
HF_USERNAME = "baskarmother"
MODEL_ID = "yolov8s-ppe-construction-v2"
# Root of the merged (PPE + keremberke) YOLO dataset built by merge_datasets().
DATASET_DIR = Path("/app/combined_ppe_dataset")
EPOCHS = 150  # upper bound; train_model() also passes patience=30 for early stopping
IMG_SIZE = 640  # square input resolution fed to YOLOv8
BATCH = 16  # training batch size
DEVICE = "0"  # CUDA device index string passed to ultralytics
# Unified class mapping: list position == YOLO class id. Both
# convert_keremberke_to_yolo() and merge_datasets() remap their source
# labels into these indices.
UNIFIED_CLASSES = [
    "person",
    "helmet",
    "vest",
    "mask",
    "gloves",
    "safety_shoe",
    "goggles",
    "no_helmet",
    "no_mask",
    "no_vest",
    "head",  # NOTE(review): index 10 is never produced by either remap table — confirm it is intentional
    "barricade",
    "dumpster",
    "excavators",
    "safety_net",
    "dump_truck",
    "truck",
    "wheel_loader",
]
def download_ppe_dataset():
    """Fetch the 51ddhesh/PPE_Detection ZIP from the Hub and unpack it.

    Returns the directory the archive was extracted into.
    """
    print("[1/5] Downloading 51ddhesh/PPE_Detection dataset...")
    zip_path = hf_hub_download(
        repo_id="51ddhesh/PPE_Detection",
        filename="PPE.zip",
        repo_type="dataset",
        cache_dir="/app/hf_cache",
        local_dir="/app/downloads",
    )
    target = Path("/app/downloads/ppe_dataset")
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target)
    print(f" Extracted to {target}")
    return target
def load_keremberke_dataset():
    """Load keremberke/construction-safety-object-detection ("full" config)."""
    print("[2/5] Loading keremberke/construction-safety-object-detection...")
    dataset = load_dataset("keremberke/construction-safety-object-detection", "full")
    print(f" Splits: {list(dataset.keys())}")
    return dataset
def convert_keremberke_to_yolo(ds, output_dir: Path):
    """Convert the keremberke COCO-style dataset to YOLO txt labels.

    ds: dataset dict with "train"/"valid"/"test" splits; each example carries a
        PIL "image" and "objects" holding COCO ``[x, y, w, h]`` pixel boxes
        plus integer category ids — assumed from the keremberke schema,
        TODO confirm against the dataset card.
    output_dir: root under which ``<split>/images`` and ``<split>/labels``
        are written.
    """
    print("[3/5] Converting keremberke dataset to YOLO format...")
    class_names = ds["train"].features["objects"].feature["category"].names
    print(f" Classes: {class_names}")
    # keremberke class name -> UNIFIED_CLASSES index.
    # NOTE(review): "mini-van" maps to 0 ("person"), which looks like a
    # mistake (a vehicle class such as 16/"truck" seems intended) — confirm.
    class_map = {
        "person": 0,
        "hardhat": 1,
        "mask": 3,
        "no-hardhat": 7,
        "no-mask": 8,
        "no-safety vest": 9,
        "gloves": 4,
        "safety shoes": 5,
        "safety vest": 2,
        "barricade": 11,
        "dumpster": 12,
        "excavators": 13,
        "safety net": 14,
        "dump truck": 15,
        "mini-van": 0,
        "truck": 16,
        "wheel loader": 17,
    }
    for split in ["train", "valid", "test"]:
        if split not in ds:
            continue
        images_dir = output_dir / split / "images"
        labels_dir = output_dir / split / "labels"
        images_dir.mkdir(parents=True, exist_ok=True)
        labels_dir.mkdir(parents=True, exist_ok=True)
        for i, example in enumerate(ds[split]):
            img = example["image"]
            # FIX: JPEG cannot store alpha/palette images; PIL raises OSError
            # when saving RGBA/P modes as .jpg, so normalize to RGB first.
            if img.mode != "RGB":
                img = img.convert("RGB")
            img_filename = f"keremberke_{split}_{i:05d}.jpg"
            img.save(images_dir / img_filename)
            width, height = img.size
            objects = example["objects"]
            label_path = labels_dir / img_filename.replace(".jpg", ".txt")
            with open(label_path, "w") as f:
                for bbox, cat in zip(objects["bbox"], objects["category"]):
                    class_name = class_names[cat]
                    if class_name not in class_map:
                        continue  # drop classes outside the unified taxonomy
                    unified_idx = class_map[class_name]
                    # COCO [x, y, w, h] in pixels -> normalized YOLO center
                    # format, clamped into [0, 1].
                    x, y, w, h = bbox
                    x_center = max(0, min(1, (x + w / 2) / width))
                    y_center = max(0, min(1, (y + h / 2) / height))
                    norm_w = max(0, min(1, w / width))
                    norm_h = max(0, min(1, h / height))
                    f.write(f"{unified_idx} {x_center:.6f} {y_center:.6f} {norm_w:.6f} {norm_h:.6f}\n")
    print(f" Converted keremberke dataset to {output_dir}")
def merge_datasets(ppe_extract_dir: Path, keremberke_dir: Path, output_dir: Path):
    """Merge both datasets into one unified YOLO directory structure.

    Copies images/labels from the extracted PPE archive (remapping its class
    ids into UNIFIED_CLASSES and prefixing files with "ppe_") and from the
    already-converted keremberke YOLO tree, then writes the combined
    data.yaml for ultralytics.

    Exits the process (sys.exit(1)) if the PPE archive layout is unrecognized.
    """
    print("[4/5] Merging datasets...")
    output_dir.mkdir(parents=True, exist_ok=True)
    # Locate the split root inside the extracted archive (layout varies by zip).
    ppe_dir = None
    for candidate in [ppe_extract_dir / "PPE", ppe_extract_dir / "ppe", ppe_extract_dir]:
        if (candidate / "train" / "images").exists():
            ppe_dir = candidate
            break
    if ppe_dir is None:
        print(" ERROR: Could not find PPE dataset structure")
        print(f" Contents: {list(ppe_extract_dir.iterdir())}")
        sys.exit(1)
    print(f" Found PPE dataset at: {ppe_dir}")
    # PPE-archive class id -> UNIFIED_CLASSES index.
    ppe_class_map = {
        0: 2,  # Vest
        1: 5,  # Safety Shoe
        2: 3,  # Mask
        3: 1,  # Helmet
        4: 6,  # Goggles
        5: 4,  # Gloves
    }
    for split in ["train", "valid", "test"]:
        out_images = output_dir / split / "images"
        out_labels = output_dir / split / "labels"
        out_images.mkdir(parents=True, exist_ok=True)
        out_labels.mkdir(parents=True, exist_ok=True)
        # --- PPE archive: copy with "ppe_" prefix and remap class ids ---
        ppe_images = ppe_dir / split / "images"
        ppe_labels = ppe_dir / split / "labels"
        if ppe_images.exists():
            for img_file in sorted(ppe_images.iterdir()):
                if img_file.suffix.lower() not in [".jpg", ".jpeg", ".png"]:
                    continue
                shutil.copy2(img_file, out_images / f"ppe_{img_file.name}")
                label_file = ppe_labels / f"{img_file.stem}.txt"
                if not label_file.exists():
                    continue  # image with no annotations -> treated as background
                remapped = []
                with open(label_file) as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) < 5:
                            continue  # skip malformed label lines
                        src_cls = int(parts[0])
                        if src_cls in ppe_class_map:
                            remapped.append(f"{ppe_class_map[src_cls]} {' '.join(parts[1:])}\n")
                with open(out_labels / f"ppe_{img_file.stem}.txt", "w") as f:
                    f.writelines(remapped)
        # --- keremberke: already in unified YOLO form, copy through unchanged ---
        k_images = keremberke_dir / split / "images"
        k_labels = keremberke_dir / split / "labels"
        if k_images.exists():
            for img_file in sorted(k_images.iterdir()):
                shutil.copy2(img_file, out_images / img_file.name)
            # FIX: guard against a missing labels dir so iterdir() cannot
            # raise FileNotFoundError when only images are present.
            if k_labels.exists():
                for label_file in sorted(k_labels.iterdir()):
                    shutil.copy2(label_file, out_labels / label_file.name)
    data_yaml = {
        "path": str(output_dir.absolute()),
        "train": "train/images",
        "val": "valid/images",
        "test": "test/images",
        "names": {i: name for i, name in enumerate(UNIFIED_CLASSES)},
        "nc": len(UNIFIED_CLASSES),
    }
    with open(output_dir / "data.yaml", "w") as f:
        yaml.dump(data_yaml, f, default_flow_style=False)
    print(f" Merged dataset at {output_dir}")
    for split in ["train", "valid", "test"]:
        img_count = len(list((output_dir / split / "images").glob("*")))
        print(f" {split}: {img_count} images")
def train_model(data_yaml_path: Path):
    """Fine-tune YOLOv8s on the merged dataset; return the ultralytics results."""
    print("[5/5] Training YOLOv8s...")
    from ultralytics import YOLO

    # Augmentation / optimizer settings, identical to the original run config.
    hyperparams = dict(
        patience=30,
        project="/app/runs",
        name="ppe_improved",
        exist_ok=True,
        pretrained=True,
        optimizer="SGD",
        lr0=0.01,
        lrf=0.01,
        momentum=0.9,
        weight_decay=0.0005,
        augment=True,
        mosaic=1.0,
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=5.0,
        translate=0.1,
        scale=0.5,
        shear=2.0,
        perspective=0.0,
        flipud=0.0,
        fliplr=0.5,
    )
    model = YOLO("yolov8s.pt")
    results = model.train(
        data=str(data_yaml_path),
        epochs=EPOCHS,
        imgsz=IMG_SIZE,
        batch=BATCH,
        device=DEVICE,
        **hyperparams,
    )
    print(" Training complete!")
    print(f" Best model: {results.best}")
    return results
def push_to_hub(best_model_path: Path):
    """Upload the trained checkpoint and a generated model card to the Hub.

    best_model_path: local path to the best.pt checkpoint produced by training.
    """
    print("Pushing model to HuggingFace Hub...")
    api = HfApi()
    repo_id = f"{HF_USERNAME}/{MODEL_ID}"
    # exist_ok=True makes creation idempotent; any other failure is only
    # logged here and would surface again on upload_file below.
    try:
        api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    except Exception as e:
        print(f" Repo creation info: {e}")
    api.upload_file(
        path_or_fileobj=str(best_model_path),
        path_in_repo="best.pt",
        repo_id=repo_id,
        repo_type="model",
    )
    # Model card: YAML front matter plus class list / training config
    # interpolated from the module-level constants.
    readme = f"""---
license: cc-by-4.0
library_name: ultralytics
tags:
- object-detection
- ppe
- construction-safety
- yolov8
---
# {MODEL_ID}
Improved PPE Compliance Detection Model for Construction Sites (v2)
## Description
This is an improved YOLOv8s model trained on a combined dataset of:
- **51ddhesh/PPE_Detection** (~10K images, 6 PPE classes)
- **keremberke/construction-safety-object-detection** (398 images, violation classes)
## Classes ({len(UNIFIED_CLASSES)})
{chr(10).join(f"- {i}: {name}" for i, name in enumerate(UNIFIED_CLASSES))}
## Usage
```python
from ultralytics import YOLO
model = YOLO("hf://{repo_id}/best.pt")
results = model.predict("image.jpg")
```
## Training Details
- Base Model: YOLOv8s
- Epochs: {EPOCHS}
- Image Size: {IMG_SIZE}x{IMG_SIZE}
- Batch Size: {BATCH}
- Augmentations: Mosaic, HSV, scale, shear, flip
## Compliance Detection
The model detects both PPE presence AND absence:
- `no_helmet`, `no_mask`, `no_vest` = violation classes
- `helmet`, `mask`, `vest` = compliance classes
"""
    api.upload_file(
        path_or_fileobj=readme.encode(),
        path_in_repo="README.md",
        repo_id=repo_id,
        repo_type="model",
    )
    print(f" Model pushed to https://huggingface.co/{repo_id}")
def main():
    """End-to-end pipeline: download, convert, merge, train, publish."""
    banner = "=" * 60
    print(banner)
    print("IMPROVED PPE DETECTION TRAINING v2")
    print(banner)
    ppe_dir = download_ppe_dataset()
    keremberke_yolo_dir = Path("/app/keremberke_yolo")
    convert_keremberke_to_yolo(load_keremberke_dataset(), keremberke_yolo_dir)
    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    merge_datasets(ppe_dir, keremberke_yolo_dir, DATASET_DIR)
    train_model(DATASET_DIR / "data.yaml")
    best_model = Path("/app/runs/ppe_improved/weights/best.pt")
    if best_model.exists():
        push_to_hub(best_model)
    else:
        # Fall back to the first best.pt found anywhere under the runs dir.
        print(f" WARNING: Best model not found at {best_model}")
        for pt_file in Path("/app/runs").rglob("best.pt"):
            print(f" Found: {pt_file}")
            push_to_hub(pt_file)
            break
    print(banner)
    print("DONE!")
    print(banner)


if __name__ == "__main__":
    main()