baskarmother commited on
Commit
acff799
·
verified ·
1 Parent(s): e31e7e4

Add PPE training script

Browse files
Files changed (1) hide show
  1. train_ppe.py +234 -0
train_ppe.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PPE Compliance Detection Model Training Script
3
+ Converts COCO-format dataset from HuggingFace to YOLO format and trains YOLOv8
4
+ """
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+ from datasets import load_dataset
9
+ from PIL import Image
10
+ import yaml
11
+ from ultralytics import YOLO
12
+ from huggingface_hub import HfApi, create_repo
13
+ import shutil
14
+
15
+ # Configuration
16
+ DATASET_NAME = "keremberke/construction-safety-object-detection"
17
+ DATASET_CONFIG = "full"
18
+ OUTPUT_DIR = Path("/app/ppe_dataset")
19
+ MODEL_SIZE = "yolov8n"
20
+ EPOCHS = 100
21
+ IMGSZ = 640
22
+ BATCH = 16
23
+ HUB_MODEL_ID = "baskarmother/yolov8-ppe-construction"
24
+
25
+ CATEGORY_NAMES = [
26
+ 'barricade', 'dumpster', 'excavators', 'gloves', 'hardhat', 'mask',
27
+ 'no-hardhat', 'no-mask', 'no-safety vest', 'person', 'safety net',
28
+ 'safety shoes', 'safety vest', 'dump truck', 'mini-van', 'truck', 'wheel loader'
29
+ ]
30
+
31
+
32
+ def convert_coco_to_yolo(example):
33
+ """Convert COCO bbox [x, y, width, height] to YOLO format."""
34
+ img_w = example['width']
35
+ img_h = example['height']
36
+ yolo_lines = []
37
+
38
+ for i in range(len(example['objects']['id'])):
39
+ cat = example['objects']['category'][i]
40
+ bbox = example['objects']['bbox'][i]
41
+ x, y, w, h = bbox
42
+ x_center = (x + w / 2) / img_w
43
+ y_center = (y + h / 2) / img_h
44
+ nw = w / img_w
45
+ nh = h / img_h
46
+ x_center = max(0, min(1, x_center))
47
+ y_center = max(0, min(1, y_center))
48
+ nw = max(0, min(1, nw))
49
+ nh = max(0, min(1, nh))
50
+ yolo_lines.append(f"{cat} {x_center:.6f} {y_center:.6f} {nw:.6f} {nh:.6f}")
51
+
52
+ return "\n".join(yolo_lines)
53
+
54
+
55
+ def prepare_dataset():
56
+ """Download and convert dataset to YOLO format."""
57
+ print(f"Loading dataset: {DATASET_NAME} ({DATASET_CONFIG})")
58
+ ds = load_dataset(DATASET_NAME, name=DATASET_CONFIG, trust_remote_code=True)
59
+
60
+ for split in ['train', 'validation', 'test']:
61
+ if split not in ds:
62
+ continue
63
+ img_dir = OUTPUT_DIR / 'images' / split.replace('validation', 'val')
64
+ lbl_dir = OUTPUT_DIR / 'labels' / split.replace('validation', 'val')
65
+ img_dir.mkdir(parents=True, exist_ok=True)
66
+ lbl_dir.mkdir(parents=True, exist_ok=True)
67
+
68
+ print(f"Processing {split}: {len(ds[split])} examples")
69
+ for idx, example in enumerate(ds[split]):
70
+ img = example['image']
71
+ img_name = f"{example['image_id']:06d}.jpg"
72
+ img_path = img_dir / img_name
73
+ img.save(img_path)
74
+
75
+ label_content = convert_coco_to_yolo(example)
76
+ label_path = lbl_dir / img_name.replace('.jpg', '.txt')
77
+ label_path.write_text(label_content)
78
+
79
+ data_yaml = {
80
+ 'path': str(OUTPUT_DIR),
81
+ 'train': 'images/train',
82
+ 'val': 'images/val',
83
+ 'test': 'images/test',
84
+ 'names': {i: name for i, name in enumerate(CATEGORY_NAMES)}
85
+ }
86
+
87
+ yaml_path = OUTPUT_DIR / 'data.yaml'
88
+ with open(yaml_path, 'w') as f:
89
+ yaml.dump(data_yaml, f, default_flow_style=False, sort_keys=False)
90
+
91
+ print(f"Dataset prepared at {OUTPUT_DIR}")
92
+ print(f"Categories: {len(CATEGORY_NAMES)}")
93
+ for i, name in enumerate(CATEGORY_NAMES):
94
+ print(f" {i}: {name}")
95
+ return yaml_path
96
+
97
+
98
+ def train_model(data_yaml_path):
99
+ """Train YOLOv8 model."""
100
+ print(f"\nInitializing YOLO {MODEL_SIZE} model...")
101
+ model = YOLO(f"{MODEL_SIZE}.pt")
102
+
103
+ print(f"Starting training: epochs={EPOCHS}, imgsz={IMGSZ}, batch={BATCH}")
104
+ results = model.train(
105
+ data=str(data_yaml_path),
106
+ epochs=EPOCHS,
107
+ imgsz=IMGSZ,
108
+ batch=BATCH,
109
+ device=0,
110
+ patience=30,
111
+ optimizer='SGD',
112
+ lr0=0.01,
113
+ lrf=0.01,
114
+ momentum=0.9,
115
+ weight_decay=0.0005,
116
+ augment=True,
117
+ mosaic=1.0,
118
+ mixup=0.0,
119
+ project='/app/runs',
120
+ name='ppe_training',
121
+ exist_ok=True,
122
+ verbose=True,
123
+ )
124
+
125
+ return model, results
126
+
127
+
128
+ def evaluate_model(model):
129
+ """Evaluate on test set."""
130
+ print("\nEvaluating on test set...")
131
+ metrics = model.val(data=str(OUTPUT_DIR / 'data.yaml'), split='test')
132
+ print(f"Test mAP@50: {metrics.box.map50:.4f}")
133
+ print(f"Test mAP@50:95: {metrics.box.map:.4f}")
134
+ return metrics
135
+
136
+
137
+ def push_to_hub(model, hub_model_id):
138
+ """Push model to HuggingFace Hub."""
139
+ print(f"\nPushing to HuggingFace Hub: {hub_model_id}")
140
+
141
+ api = HfApi()
142
+ try:
143
+ create_repo(hub_model_id, repo_type="model", exist_ok=True)
144
+ except Exception as e:
145
+ print(f"Repo creation note: {e}")
146
+
147
+ best_pt = Path('/app/runs/ppe_training/weights/best.pt')
148
+ if not best_pt.exists():
149
+ print("WARNING: best.pt not found, checking for last.pt")
150
+ best_pt = Path('/app/runs/ppe_training/weights/last.pt')
151
+
152
+ if best_pt.exists():
153
+ api.upload_file(
154
+ path_or_fileobj=str(best_pt),
155
+ path_in_repo="best.pt",
156
+ repo_id=hub_model_id,
157
+ repo_type="model",
158
+ )
159
+ print(f"Model uploaded to https://huggingface.co/{hub_model_id}")
160
+ else:
161
+ print("ERROR: No weights file found!")
162
+ return False
163
+
164
+ readme = f"""---
165
+ tags:
166
+ - ultralytics
167
+ - vision
168
+ - object-detection
169
+ - yolov8
170
+ - ppe
171
+ - construction-safety
172
+ - safety
173
+ license: mit
174
+ ---
175
+
176
+ # YOLOv8 PPE Compliance Detection for Construction Sites
177
+
178
+ This model detects Personal Protective Equipment (PPE) compliance on construction sites.
179
+
180
+ ## Classes ({len(CATEGORY_NAMES)} categories)
181
+
182
+ {chr(10).join([f"- **{i}**: {name}" for i, name in enumerate(CATEGORY_NAMES)])}
183
+
184
+ ## Training Details
185
+
186
+ - **Base Model**: {MODEL_SIZE}
187
+ - **Dataset**: [keremberke/construction-safety-object-detection](https://huggingface.co/datasets/keremberke/construction-safety-object-detection)
188
+ - **Image Size**: {IMGSZ}x{IMGSZ}
189
+ - **Epochs**: {EPOCHS}
190
+ - **Optimizer**: SGD (lr=0.01, momentum=0.9)
191
+
192
+ ## Usage
193
+
194
+ ```python
195
+ from ultralytics import YOLO
196
+ from huggingface_hub import hf_hub_download
197
+
198
+ model = YOLO(hf_hub_download("{hub_model_id}", "best.pt"))
199
+ results = model("your_image.jpg")
200
+ results[0].plot()
201
+ ```
202
+ """
203
+ api.upload_file(
204
+ path_or_fileobj=readme.encode(),
205
+ path_in_repo="README.md",
206
+ repo_id=hub_model_id,
207
+ repo_type="model",
208
+ )
209
+
210
+ return True
211
+
212
+
213
+ def main():
214
+ hub_model_id = os.environ.get("HUB_MODEL_ID", HUB_MODEL_ID)
215
+
216
+ print("=" * 60)
217
+ print("PPE Compliance Detection - Model Training")
218
+ print("=" * 60)
219
+
220
+ data_yaml_path = prepare_dataset()
221
+ model, results = train_model(data_yaml_path)
222
+ metrics = evaluate_model(model)
223
+
224
+ if hub_model_id:
225
+ success = push_to_hub(model, hub_model_id)
226
+ if success:
227
+ print(f"\nModel successfully published to https://huggingface.co/{hub_model_id}")
228
+
229
+ print("\nTraining complete!")
230
+ return model, metrics
231
+
232
+
233
+ if __name__ == "__main__":
234
+ main()