# ppe-training-scripts / train_ppe.py
# baskarmother's picture
# Add PPE training script
# acff799 verified
"""
PPE Compliance Detection Model Training Script
Converts COCO-format dataset from HuggingFace to YOLO format and trains YOLOv8
"""
import os
import sys
from pathlib import Path
from datasets import load_dataset
from PIL import Image
import yaml
from ultralytics import YOLO
from huggingface_hub import HfApi, create_repo
import shutil
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Hugging Face dataset repository and the configuration name to load.
DATASET_NAME = "keremberke/construction-safety-object-detection"
DATASET_CONFIG = "full"

# Root directory that receives images/, labels/ and data.yaml.
OUTPUT_DIR = Path("/app/ppe_dataset")

# Training settings: checkpoint stem (".pt" is appended when the model is
# loaded), number of epochs, input image size and batch size.
MODEL_SIZE = "yolov8n"
EPOCHS = 100
IMGSZ = 640
BATCH = 16

# Default Hub repository id for the trained weights.
HUB_MODEL_ID = "baskarmother/yolov8-ppe-construction"

# Class names; the position in this list is the class id written to data.yaml.
# NOTE(review): the order is assumed to match the dataset's category ids --
# confirm against the dataset card.
CATEGORY_NAMES = [
    "barricade",
    "dumpster",
    "excavators",
    "gloves",
    "hardhat",
    "mask",
    "no-hardhat",
    "no-mask",
    "no-safety vest",
    "person",
    "safety net",
    "safety shoes",
    "safety vest",
    "dump truck",
    "mini-van",
    "truck",
    "wheel loader",
]
def convert_coco_to_yolo(example):
    """Convert one example's COCO boxes to YOLO label-file text.

    Each COCO box ``[x, y, width, height]`` (top-left corner plus size) becomes
    one line ``"<category> <x_center> <y_center> <width> <height>"`` whose four
    coordinates are divided by the image size and printed with six decimals.

    Boxes are clipped to the image rectangle *before* the center and size are
    computed. Clamping the center and the size independently (the previous
    behavior) leaves a box that overhangs an edge still extending outside the
    image; clipping the corners keeps exactly the visible part. Boxes with no
    area left after clipping are dropped.

    Args:
        example: Mapping with ``'width'`` and ``'height'`` (image size) and
            ``'objects'``, a mapping holding the parallel lists ``'category'``
            and ``'bbox'``.

    Returns:
        The label lines joined with newlines; an empty string when no box
        remains.

    Raises:
        ValueError: If the image width or height is not positive.
    """
    img_w = example['width']
    img_h = example['height']
    if img_w <= 0 or img_h <= 0:
        raise ValueError(f"Invalid image size: {img_w}x{img_h}")

    objects = example['objects']
    yolo_lines = []
    for cat, (x, y, w, h) in zip(objects['category'], objects['bbox']):
        # Clip both corners to the image so the box never leaves [0, 1].
        x_min = min(max(x, 0), img_w)
        y_min = min(max(y, 0), img_h)
        x_max = min(max(x + w, 0), img_w)
        y_max = min(max(y + h, 0), img_h)

        box_w = x_max - x_min
        box_h = y_max - y_min
        if box_w <= 0 or box_h <= 0:
            # Nothing of this box is inside the image (or it has no area).
            continue

        x_center = (x_min + x_max) / 2 / img_w
        y_center = (y_min + y_max) / 2 / img_h
        nw = box_w / img_w
        nh = box_h / img_h
        yolo_lines.append(f"{cat} {x_center:.6f} {y_center:.6f} {nw:.6f} {nh:.6f}")
    return "\n".join(yolo_lines)
def prepare_dataset():
    """Download the dataset and write it under ``OUTPUT_DIR`` in YOLO layout.

    For every split present among ``train``/``validation``/``test`` the images
    are saved as ``images/<split>/<image_id>.jpg`` and the converted labels as
    ``labels/<split>/<image_id>.txt`` (``validation`` is stored as ``val``).
    A ``data.yaml`` describing the dataset is then written to ``OUTPUT_DIR``
    and the class list is printed.

    Returns:
        Path of the generated ``data.yaml``.
    """
    print(f"Loading dataset: {DATASET_NAME} ({DATASET_CONFIG})")
    # NOTE(review): trust_remote_code=True runs the dataset repository's own
    # loading script on this machine; keep it only for datasets you trust.
    ds = load_dataset(DATASET_NAME, name=DATASET_CONFIG, trust_remote_code=True)

    for split in ('train', 'validation', 'test'):
        if split not in ds:
            continue

        split_dir = 'val' if split == 'validation' else split
        img_dir = OUTPUT_DIR / 'images' / split_dir
        lbl_dir = OUTPUT_DIR / 'labels' / split_dir
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)

        print(f"Processing {split}: {len(ds[split])} examples")
        for example in ds[split]:
            img = example['image']
            # JPEG cannot store alpha or palette images; saving such a mode
            # raises, so convert anything other than RGB/grayscale first.
            if img.mode not in ('RGB', 'L'):
                img = img.convert('RGB')

            # Image and label share the same zero-padded stem.
            stem = f"{example['image_id']:06d}"
            img.save(img_dir / f"{stem}.jpg")
            label_path = lbl_dir / f"{stem}.txt"
            label_path.write_text(convert_coco_to_yolo(example), encoding='utf-8')

    data_yaml = {
        'path': str(OUTPUT_DIR),
        'train': 'images/train',
        'val': 'images/val',
        'test': 'images/test',
        'names': dict(enumerate(CATEGORY_NAMES)),
    }
    yaml_path = OUTPUT_DIR / 'data.yaml'
    with open(yaml_path, 'w', encoding='utf-8') as f:
        yaml.dump(data_yaml, f, default_flow_style=False, sort_keys=False)

    print(f"Dataset prepared at {OUTPUT_DIR}")
    print(f"Categories: {len(CATEGORY_NAMES)}")
    for i, name in enumerate(CATEGORY_NAMES):
        print(f" {i}: {name}")
    return yaml_path
def train_model(data_yaml_path):
    """Load the ``<MODEL_SIZE>.pt`` checkpoint and train it on the dataset.

    Args:
        data_yaml_path: Location of the dataset ``data.yaml``; it is handed to
            the trainer as a string.

    Returns:
        A ``(model, results)`` tuple: the ``YOLO`` object and the value
        returned by ``model.train``.
    """
    print(f"\nInitializing YOLO {MODEL_SIZE} model...")
    model = YOLO(f"{MODEL_SIZE}.pt")

    print(f"Starting training: epochs={EPOCHS}, imgsz={IMGSZ}, batch={BATCH}")
    train_args = {
        # Data and schedule
        'data': str(data_yaml_path),
        'epochs': EPOCHS,
        'imgsz': IMGSZ,
        'batch': BATCH,
        'device': 0,
        'patience': 30,
        # Optimizer
        'optimizer': 'SGD',
        'lr0': 0.01,
        'lrf': 0.01,
        'momentum': 0.9,
        'weight_decay': 0.0005,
        # Augmentation
        'augment': True,
        'mosaic': 1.0,
        'mixup': 0.0,
        # Output location; push_to_hub reads the weights from this run folder.
        'project': '/app/runs',
        'name': 'ppe_training',
        'exist_ok': True,
        'verbose': True,
    }
    results = model.train(**train_args)
    return model, results
def evaluate_model(model):
    """Validate ``model`` on the ``test`` split and print its box mAP scores.

    Args:
        model: Object exposing ``val(data=..., split=...)``; in this script,
            the model returned by ``train_model``.

    Returns:
        The metrics object returned by ``model.val``.
    """
    print("\nEvaluating on test set...")
    data_yaml = OUTPUT_DIR / 'data.yaml'
    metrics = model.val(data=str(data_yaml), split='test')

    box = metrics.box
    print(f"Test mAP@50: {box.map50:.4f}")
    print(f"Test mAP@50:95: {box.map:.4f}")
    return metrics
def push_to_hub(model, hub_model_id):
    """Upload the trained weights and a generated README to the Hub.

    The weights are read from the training run folder on disk: ``best.pt`` if
    it exists, otherwise ``last.pt``. Either file is stored in the repository
    as ``best.pt``.

    Args:
        model: Not used by this function; the weights come from disk.
        hub_model_id: Id of the target model repository.

    Returns:
        ``True`` after both uploads, ``False`` when no weights file is found.
    """
    print(f"\nPushing to HuggingFace Hub: {hub_model_id}")
    api = HfApi()

    # Repo creation is best effort: a failure is reported and the uploads are
    # still attempted.
    try:
        create_repo(hub_model_id, repo_type="model", exist_ok=True)
    except Exception as e:
        print(f"Repo creation note: {e}")

    weights = Path('/app/runs/ppe_training/weights/best.pt')
    if not weights.exists():
        print("WARNING: best.pt not found, checking for last.pt")
        weights = Path('/app/runs/ppe_training/weights/last.pt')
    if not weights.exists():
        print("ERROR: No weights file found!")
        return False

    # Both uploads go to the same repository.
    target = {'repo_id': hub_model_id, 'repo_type': "model"}

    api.upload_file(
        path_or_fileobj=str(weights),
        path_in_repo="best.pt",
        **target,
    )
    print(f"Model uploaded to https://huggingface.co/{hub_model_id}")

    # One markdown bullet per class, in class-id order.
    class_list = "\n".join(
        f"- **{i}**: {name}" for i, name in enumerate(CATEGORY_NAMES)
    )
    readme = f"""---
tags:
- ultralytics
- vision
- object-detection
- yolov8
- ppe
- construction-safety
- safety
license: mit
---
# YOLOv8 PPE Compliance Detection for Construction Sites
This model detects Personal Protective Equipment (PPE) compliance on construction sites.
## Classes ({len(CATEGORY_NAMES)} categories)
{class_list}
## Training Details
- **Base Model**: {MODEL_SIZE}
- **Dataset**: [keremberke/construction-safety-object-detection](https://huggingface.co/datasets/keremberke/construction-safety-object-detection)
- **Image Size**: {IMGSZ}x{IMGSZ}
- **Epochs**: {EPOCHS}
- **Optimizer**: SGD (lr=0.01, momentum=0.9)
## Usage
```python
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
model = YOLO(hf_hub_download("{hub_model_id}", "best.pt"))
results = model("your_image.jpg")
results[0].plot()
```
"""
    api.upload_file(
        path_or_fileobj=readme.encode(),
        path_in_repo="README.md",
        **target,
    )
    return True
def main():
    """Run the whole pipeline: prepare data, train, evaluate, then publish.

    The Hub repository id comes from the ``HUB_MODEL_ID`` environment variable,
    falling back to the module-level ``HUB_MODEL_ID`` constant. Publishing is
    skipped when the resulting id is empty.

    Returns:
        A ``(model, metrics)`` tuple from training and evaluation.
    """
    hub_model_id = os.environ.get("HUB_MODEL_ID", HUB_MODEL_ID)

    banner = "=" * 60
    print(banner)
    print("PPE Compliance Detection - Model Training")
    print(banner)

    yaml_path = prepare_dataset()
    model, _train_results = train_model(yaml_path)
    metrics = evaluate_model(model)

    # Publish only when an id is configured; report success only if it worked.
    if hub_model_id and push_to_hub(model, hub_model_id):
        print(f"\nModel successfully published to https://huggingface.co/{hub_model_id}")

    print("\nTraining complete!")
    return model, metrics


if __name__ == "__main__":
    main()