| """ |
| PPE Compliance Detection Model Training Script |
| Converts COCO-format dataset from HuggingFace to YOLO format and trains YOLOv8 |
| """ |
| import os |
| import sys |
| from pathlib import Path |
| from datasets import load_dataset |
| from PIL import Image |
| import yaml |
| from ultralytics import YOLO |
| from huggingface_hub import HfApi, create_repo |
| import shutil |
|
|
| |
# --- Dataset / output locations ---------------------------------------------
DATASET_NAME = "keremberke/construction-safety-object-detection"
DATASET_CONFIG = "full"
OUTPUT_DIR = Path("/app/ppe_dataset")

# --- Training hyper-parameters ----------------------------------------------
MODEL_SIZE = "yolov8n"
EPOCHS = 100
IMGSZ = 640
BATCH = 16

# --- Publishing --------------------------------------------------------------
# Default Hub repo id; main() lets the HUB_MODEL_ID environment variable
# override it.
HUB_MODEL_ID = "baskarmother/yolov8-ppe-construction"

# Class names in label-id order: position i becomes class id i in data.yaml,
# while the label files carry the dataset's own integer category ids.
# NOTE(review): this order is assumed to match the dataset's category ids —
# confirm against the dataset card.
CATEGORY_NAMES = [
    'barricade',
    'dumpster',
    'excavators',
    'gloves',
    'hardhat',
    'mask',
    'no-hardhat',
    'no-mask',
    'no-safety vest',
    'person',
    'safety net',
    'safety shoes',
    'safety vest',
    'dump truck',
    'mini-van',
    'truck',
    'wheel loader',
]
|
|
|
|
def convert_coco_to_yolo(example):
    """Convert one example's COCO boxes into the text of a YOLO label file.

    Args:
        example: A dataset row with ``width``/``height`` (pixels) and an
            ``objects`` mapping holding parallel ``category`` and ``bbox``
            lists. Each ``bbox`` is COCO ``[x_min, y_min, width, height]``
            in pixels.

    Returns:
        One ``"<class> <x_center> <y_center> <width> <height>"`` line per
        usable box, coordinates normalized to [0, 1] with 6 decimals, joined
        with newlines. Returns ``""`` when no box survives clipping.
    """
    img_w = example['width']
    img_h = example['height']
    objects = example['objects']
    yolo_lines = []

    for cat, (x, y, w, h) in zip(objects['category'], objects['bbox']):
        # Clip the box corners to the image first, then derive center/size.
        # Clamping center and size independently would leave boxes that hang
        # off an edge still extending outside the image.
        x1 = min(max(x, 0.0), img_w)
        y1 = min(max(y, 0.0), img_h)
        x2 = min(max(x + w, 0.0), img_w)
        y2 = min(max(y + h, 0.0), img_h)
        if x2 <= x1 or y2 <= y1:
            # Zero-area or entirely out-of-frame box (also covers a 0-sized
            # image, avoiding a division by zero below).
            continue

        x_center = (x1 + x2) / 2 / img_w
        y_center = (y1 + y2) / 2 / img_h
        nw = (x2 - x1) / img_w
        nh = (y2 - y1) / img_h
        yolo_lines.append(f"{cat} {x_center:.6f} {y_center:.6f} {nw:.6f} {nh:.6f}")

    return "\n".join(yolo_lines)
|
|
|
|
def _yolo_split_name(split):
    """Map a Hugging Face split name to the directory name used in data.yaml."""
    return 'val' if split == 'validation' else split


def _dataset_category_names(split_ds):
    """Best-effort read of the dataset's ``objects.category`` ClassLabel names.

    Returns None when the feature schema is not shaped as expected, so the
    caller treats it as "unable to check" rather than as an error.
    """
    try:
        objects = split_ds.features['objects']
        # NOTE(review): depending on the datasets version, ``objects`` may be a
        # Sequence wrapping a dict or a dict of Sequences; handle both.
        if hasattr(objects, 'feature'):
            category = objects.feature['category']
        else:
            category = objects['category']
        category = getattr(category, 'feature', category)
        return list(category.names)
    except (AttributeError, KeyError, TypeError):
        return None


def _export_split(split_ds, split_dir):
    """Write every example of one split as ``<image_id>.jpg`` plus a YOLO ``.txt`` label."""
    img_dir = OUTPUT_DIR / 'images' / split_dir
    lbl_dir = OUTPUT_DIR / 'labels' / split_dir
    img_dir.mkdir(parents=True, exist_ok=True)
    lbl_dir.mkdir(parents=True, exist_ok=True)

    for example in split_ds:
        stem = f"{example['image_id']:06d}"
        img = example['image']
        # JPEG cannot store alpha or palette images (PIL raises OSError for
        # e.g. RGBA or P mode), so normalize to RGB before saving.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img.save(img_dir / f"{stem}.jpg")
        (lbl_dir / f"{stem}.txt").write_text(convert_coco_to_yolo(example), encoding='utf-8')


def prepare_dataset():
    """Download the dataset and export it to OUTPUT_DIR in YOLO layout.

    For each of the train/validation/test splits present in the dataset, this
    writes ``images/<split>/*.jpg`` and ``labels/<split>/*.txt`` ("validation"
    is stored as "val"). It then writes ``data.yaml`` describing the layout,
    with class names taken from CATEGORY_NAMES.

    Returns:
        Path to the written data.yaml.
    """
    print(f"Loading dataset: {DATASET_NAME} ({DATASET_CONFIG})")
    ds = load_dataset(DATASET_NAME, name=DATASET_CONFIG, trust_remote_code=True)

    # CATEGORY_NAMES is hard-coded; a silent order mismatch with the dataset
    # would mislabel every class, so warn when the dataset says otherwise.
    if 'train' in ds:
        ds_names = _dataset_category_names(ds['train'])
        if ds_names is not None and ds_names != CATEGORY_NAMES:
            print("WARNING: dataset category names differ from CATEGORY_NAMES; "
                  "class ids in data.yaml may be mismatched.")
            print(f"  dataset names: {ds_names}")

    exported = set()
    for split in ('train', 'validation', 'test'):
        if split not in ds:
            continue
        split_dir = _yolo_split_name(split)
        print(f"Processing {split}: {len(ds[split])} examples")
        _export_split(ds[split], split_dir)
        exported.add(split_dir)

    # train/val are always listed (training needs both); test is listed only
    # when it was actually exported, so nothing points at a missing directory.
    data_yaml = {
        'path': str(OUTPUT_DIR),
        'train': 'images/train',
        'val': 'images/val',
    }
    if 'test' in exported:
        data_yaml['test'] = 'images/test'
    data_yaml['names'] = {i: name for i, name in enumerate(CATEGORY_NAMES)}

    yaml_path = OUTPUT_DIR / 'data.yaml'
    with open(yaml_path, 'w', encoding='utf-8') as f:
        yaml.dump(data_yaml, f, default_flow_style=False, sort_keys=False)

    print(f"Dataset prepared at {OUTPUT_DIR}")
    print(f"Categories: {len(CATEGORY_NAMES)}")
    for i, name in enumerate(CATEGORY_NAMES):
        print(f" {i}: {name}")
    return yaml_path
|
|
|
|
def _select_device(device=None):
    """Return *device* unchanged, or pick GPU 0 / 'cpu' when it is None."""
    if device is not None:
        return device
    try:
        import torch
    except ImportError:
        return 'cpu'
    return 0 if torch.cuda.is_available() else 'cpu'


def train_model(data_yaml_path, device=None):
    """Fine-tune a pretrained YOLOv8 checkpoint on the prepared dataset.

    Args:
        data_yaml_path: Path to the data.yaml written by prepare_dataset().
        device: Ultralytics device spec (e.g. 0, "0,1", "cpu"). When None,
            uses GPU 0 if torch reports CUDA available, otherwise "cpu".
            The original code always used device 0.

    Returns:
        ``(model, results)``: the YOLO model object and whatever
        ``model.train()`` returned.
    """
    print(f"\nInitializing YOLO {MODEL_SIZE} model...")
    model = YOLO(f"{MODEL_SIZE}.pt")

    resolved_device = _select_device(device)
    print(f"Starting training: epochs={EPOCHS}, imgsz={IMGSZ}, batch={BATCH}")
    print(f"Training device: {resolved_device}")
    results = model.train(
        data=str(data_yaml_path),
        epochs=EPOCHS,
        imgsz=IMGSZ,
        batch=BATCH,
        device=resolved_device,
        patience=30,
        optimizer='SGD',
        lr0=0.01,
        lrf=0.01,
        momentum=0.9,
        weight_decay=0.0005,
        # NOTE(review): in Ultralytics, `augment` is documented as test-time
        # augmentation for predict/val. Training augmentation comes from
        # mosaic/mixup etc. Confirm this flag is intended.
        augment=True,
        mosaic=1.0,
        mixup=0.0,
        # Run directory /app/runs/ppe_training is also where push_to_hub
        # looks for weights by default.
        project='/app/runs',
        name='ppe_training',
        exist_ok=True,
        verbose=True,
    )

    return model, results
|
|
|
|
def _eval_split(data_yaml_path):
    """Return 'test' when data.yaml points at a non-empty test image dir, else 'val'."""
    with open(data_yaml_path, encoding='utf-8') as f:
        cfg = yaml.safe_load(f) or {}
    test_entry = cfg.get('test')
    if not test_entry:
        return 'val'
    if not isinstance(test_entry, str):
        # Non-string entries (e.g. a list of dirs) are left for Ultralytics to resolve.
        return 'test'
    root = Path(cfg.get('path') or Path(data_yaml_path).parent)
    test_dir = root / test_entry
    has_images = test_dir.is_dir() and any(test_dir.iterdir())
    return 'test' if has_images else 'val'


def evaluate_model(model, data_yaml_path=None):
    """Evaluate *model* on the held-out split and print box mAP.

    Uses the test split when data.yaml declares one whose image directory
    exists and is non-empty. Otherwise it falls back to the val split, so a
    dataset without a test split does not abort the run after training.

    Args:
        model: Trained YOLO model (anything exposing ``.val(data=, split=)``).
        data_yaml_path: Path to data.yaml; defaults to ``OUTPUT_DIR / 'data.yaml'``.

    Returns:
        The metrics object returned by ``model.val()``.
    """
    if data_yaml_path is not None:
        yaml_path = Path(data_yaml_path)
    else:
        yaml_path = OUTPUT_DIR / 'data.yaml'
    split = _eval_split(yaml_path)

    if split == 'test':
        print("\nEvaluating on test set...")
    else:
        print(f"\nNo test split found; evaluating on {split} set instead...")
    metrics = model.val(data=str(yaml_path), split=split)

    label = split.capitalize()
    print(f"{label} mAP@50: {metrics.box.map50:.4f}")
    print(f"{label} mAP@50:95: {metrics.box.map:.4f}")
    return metrics
|
|
|
|
def _find_weights(model):
    """Return the checkpoint file to publish, or None if none exists.

    Prefers the trainer's own ``best``/``last`` paths when the model carries
    a ``trainer`` attribute, then falls back to the default run directory
    used by train_model().
    """
    trainer = getattr(model, 'trainer', None)
    default_dir = Path('/app/runs/ppe_training/weights')
    best_candidates = [getattr(trainer, 'best', None), default_dir / 'best.pt']
    last_candidates = [getattr(trainer, 'last', None), default_dir / 'last.pt']

    for candidate in best_candidates:
        if candidate and Path(candidate).exists():
            return Path(candidate)
    print("WARNING: best.pt not found, checking for last.pt")
    for candidate in last_candidates:
        if candidate and Path(candidate).exists():
            return Path(candidate)
    return None


def _build_model_card(hub_model_id):
    """Render the README.md (YAML front-matter + markdown) for the Hub repo."""
    class_lines = [f"- **{i}**: {name}" for i, name in enumerate(CATEGORY_NAMES)]
    lines = [
        "---",
        "tags:",
        "- ultralytics",
        "- vision",
        "- object-detection",
        "- yolov8",
        "- ppe",
        "- construction-safety",
        "- safety",
        # NOTE(review): Ultralytics distributes YOLOv8 under AGPL-3.0, and the
        # dataset card has its own license. Confirm "mit" is appropriate here.
        "license: mit",
        "---",
        "",
        "# YOLOv8 PPE Compliance Detection for Construction Sites",
        "",
        "This model detects Personal Protective Equipment (PPE) compliance on construction sites.",
        "",
        f"## Classes ({len(CATEGORY_NAMES)} categories)",
        "",
        *class_lines,
        "",
        "## Training Details",
        "",
        f"- **Base Model**: {MODEL_SIZE}",
        f"- **Dataset**: [{DATASET_NAME}](https://huggingface.co/datasets/{DATASET_NAME})",
        f"- **Image Size**: {IMGSZ}x{IMGSZ}",
        f"- **Epochs**: {EPOCHS}",
        "- **Optimizer**: SGD (lr=0.01, momentum=0.9)",
        "",
        "## Usage",
        "",
        "```python",
        "from ultralytics import YOLO",
        "from huggingface_hub import hf_hub_download",
        "",
        f'model = YOLO(hf_hub_download("{hub_model_id}", "best.pt"))',
        'results = model("your_image.jpg")',
        "results[0].plot()",
        "```",
    ]
    return "\n".join(lines) + "\n"


def push_to_hub(model, hub_model_id):
    """Upload the trained weights and a model card to the Hugging Face Hub.

    Args:
        model: The trained YOLO model; its ``trainer`` paths (if present) are
            used to locate the checkpoint.
        hub_model_id: Target repo id, e.g. "user/repo".

    Returns:
        True when weights and README were uploaded. False when no checkpoint
        file was found. Upload errors from huggingface_hub propagate.
    """
    print(f"\nPushing to HuggingFace Hub: {hub_model_id}")

    api = HfApi()
    try:
        create_repo(hub_model_id, repo_type="model", exist_ok=True)
    except Exception as e:
        # Best-effort: if the repo truly cannot be used, the upload below is
        # expected to fail loudly.
        print(f"Repo creation note: {e}")

    weights = _find_weights(model)
    if weights is None:
        print("ERROR: No weights file found!")
        return False

    # Published as best.pt even when falling back to last.pt, so the usage
    # snippet in the model card always works.
    api.upload_file(
        path_or_fileobj=str(weights),
        path_in_repo="best.pt",
        repo_id=hub_model_id,
        repo_type="model",
    )
    print(f"Model uploaded to https://huggingface.co/{hub_model_id}")

    api.upload_file(
        path_or_fileobj=_build_model_card(hub_model_id).encode("utf-8"),
        path_in_repo="README.md",
        repo_id=hub_model_id,
        repo_type="model",
    )

    return True
|
|
|
|
def main():
    """Run the pipeline: prepare data, train, evaluate, then optionally publish.

    The Hub repo id comes from the HUB_MODEL_ID environment variable when set,
    otherwise from the module-level HUB_MODEL_ID. An empty value skips
    publishing.

    Returns:
        ``(model, metrics)`` from training and evaluation.
    """
    target_repo = os.environ.get("HUB_MODEL_ID", HUB_MODEL_ID)
    banner = "=" * 60

    print(banner)
    print("PPE Compliance Detection - Model Training")
    print(banner)

    trained_model, _ = train_model(prepare_dataset())
    eval_metrics = evaluate_model(trained_model)

    # push_to_hub is only attempted when a repo id is configured.
    if target_repo and push_to_hub(trained_model, target_repo):
        print(f"\nModel successfully published to https://huggingface.co/{target_repo}")

    print("\nTraining complete!")
    return trained_model, eval_metrics
|
|
|
|
# Execute the pipeline only when run as a script, never on import.
if __name__ == "__main__":
    main()
|
|