baskarmother commited on
Commit
acd2897
·
verified ·
1 Parent(s): 87e2c95

Upload train_ppe_improved.py

Browse files
Files changed (1) hide show
  1. train_ppe_improved.py +378 -0
train_ppe_improved.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Improved PPE Compliance Detection Training Script
4
+ Combines multiple datasets for better coverage:
5
+ 1. 51ddhesh/PPE_Detection (~10K images, 6 PPE classes, YOLO format)
6
+ 2. keremberke/construction-safety-object-detection (398 images, 17 classes incl. violations)
7
+
8
+ Trains YOLOv8s on combined data.
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import zipfile
14
+ import shutil
15
+ from pathlib import Path
16
+ from huggingface_hub import hf_hub_download, HfApi
17
+ from datasets import load_dataset
18
+ from PIL import Image
19
+ import yaml
20
+
21
+ # ========== CONFIG ==========
22
+ HF_USERNAME = "baskarmother"
23
+ MODEL_ID = "yolov8s-ppe-construction-v2"
24
+ DATASET_DIR = Path("/app/combined_ppe_dataset")
25
+ EPOCHS = 150
26
+ IMG_SIZE = 640
27
+ BATCH = 16
28
+ DEVICE = "0"
29
+
30
+ # Unified class mapping
31
+ UNIFIED_CLASSES = [
32
+ "person",
33
+ "helmet",
34
+ "vest",
35
+ "mask",
36
+ "gloves",
37
+ "safety_shoe",
38
+ "goggles",
39
+ "no_helmet",
40
+ "no_mask",
41
+ "no_vest",
42
+ "head",
43
+ "barricade",
44
+ "dumpster",
45
+ "excavators",
46
+ "safety_net",
47
+ "dump_truck",
48
+ "truck",
49
+ "wheel_loader",
50
+ ]
51
+
52
+
53
+ def download_ppe_dataset():
54
+ """Download 51ddhesh/PPE_Detection ZIP and extract."""
55
+ print("[1/5] Downloading 51ddhesh/PPE_Detection dataset...")
56
+ zip_path = hf_hub_download(
57
+ repo_id="51ddhesh/PPE_Detection",
58
+ filename="PPE.zip",
59
+ repo_type="dataset",
60
+ cache_dir="/app/hf_cache",
61
+ local_dir="/app/downloads",
62
+ local_dir_use_symlinks=False,
63
+ )
64
+ extract_dir = Path("/app/downloads/ppe_dataset")
65
+ extract_dir.mkdir(parents=True, exist_ok=True)
66
+ with zipfile.ZipFile(zip_path, 'r') as zf:
67
+ zf.extractall(extract_dir)
68
+ print(f" Extracted to {extract_dir}")
69
+ return extract_dir
70
+
71
+
72
+ def load_keremberke_dataset():
73
+ """Load keremberke construction-safety-object-detection."""
74
+ print("[2/5] Loading keremberke/construction-safety-object-detection...")
75
+ ds = load_dataset("keremberke/construction-safety-object-detection")
76
+ print(f" Splits: {list(ds.keys())}")
77
+ return ds
78
+
79
+
80
+ def convert_keremberke_to_yolo(ds, output_dir: Path):
81
+ """Convert keremberke COCO-style dataset to YOLO format."""
82
+ print("[3/5] Converting keremberke dataset to YOLO format...")
83
+ class_names = ds["train"].features["objects"].feature["category"].names
84
+ print(f" Classes: {class_names}")
85
+
86
+ class_map = {
87
+ "person": 0,
88
+ "hardhat": 1,
89
+ "mask": 3,
90
+ "no-hardhat": 7,
91
+ "no-mask": 8,
92
+ "no-safety vest": 9,
93
+ "gloves": 4,
94
+ "safety shoes": 5,
95
+ "safety vest": 2,
96
+ "barricade": 11,
97
+ "dumpster": 12,
98
+ "excavators": 13,
99
+ "safety net": 14,
100
+ "dump truck": 15,
101
+ "mini-van": 0,
102
+ "truck": 16,
103
+ "wheel loader": 17,
104
+ }
105
+
106
+ for split in ["train", "valid", "test"]:
107
+ if split not in ds:
108
+ continue
109
+ images_dir = output_dir / split / "images"
110
+ labels_dir = output_dir / split / "labels"
111
+ images_dir.mkdir(parents=True, exist_ok=True)
112
+ labels_dir.mkdir(parents=True, exist_ok=True)
113
+
114
+ for i, example in enumerate(ds[split]):
115
+ img = example["image"]
116
+ img_filename = f"keremberke_{split}_{i:05d}.jpg"
117
+ img_path = images_dir / img_filename
118
+ img.save(img_path)
119
+
120
+ width, height = img.size
121
+ objects = example["objects"]
122
+ bboxes = objects["bbox"]
123
+ categories = objects["category"]
124
+
125
+ label_filename = img_filename.replace(".jpg", ".txt")
126
+ label_path = labels_dir / label_filename
127
+
128
+ with open(label_path, "w") as f:
129
+ for bbox, cat in zip(bboxes, categories):
130
+ class_name = class_names[cat]
131
+ if class_name not in class_map:
132
+ continue
133
+ unified_idx = class_map[class_name]
134
+
135
+ x, y, w, h = bbox
136
+ x_center = (x + w / 2) / width
137
+ y_center = (y + h / 2) / height
138
+ norm_w = w / width
139
+ norm_h = h / height
140
+
141
+ x_center = max(0, min(1, x_center))
142
+ y_center = max(0, min(1, y_center))
143
+ norm_w = max(0, min(1, norm_w))
144
+ norm_h = max(0, min(1, norm_h))
145
+
146
+ f.write(f"{unified_idx} {x_center:.6f} {y_center:.6f} {norm_w:.6f} {norm_h:.6f}\n")
147
+
148
+ print(f" Converted keremberke dataset to {output_dir}")
149
+
150
+
151
+ def merge_datasets(ppe_extract_dir: Path, keremberke_dir: Path, output_dir: Path):
152
+ """Merge both datasets into unified YOLO structure."""
153
+ print("[4/5] Merging datasets...")
154
+ output_dir.mkdir(parents=True, exist_ok=True)
155
+
156
+ ppe_dir = None
157
+ for candidate in [ppe_extract_dir / "PPE", ppe_extract_dir / "ppe", ppe_extract_dir]:
158
+ if (candidate / "train" / "images").exists():
159
+ ppe_dir = candidate
160
+ break
161
+
162
+ if ppe_dir is None:
163
+ print(" ERROR: Could not find PPE dataset structure")
164
+ print(f" Contents: {list(ppe_extract_dir.iterdir())}")
165
+ sys.exit(1)
166
+
167
+ print(f" Found PPE dataset at: {ppe_dir}")
168
+
169
+ ppe_class_map = {
170
+ 0: 2, # Vest
171
+ 1: 5, # Safety Shoe
172
+ 2: 3, # Mask
173
+ 3: 1, # Helmet
174
+ 4: 6, # Goggles
175
+ 5: 4, # Gloves
176
+ }
177
+
178
+ for split in ["train", "valid", "test"]:
179
+ out_images = output_dir / split / "images"
180
+ out_labels = output_dir / split / "labels"
181
+ out_images.mkdir(parents=True, exist_ok=True)
182
+ out_labels.mkdir(parents=True, exist_ok=True)
183
+
184
+ ppe_images = ppe_dir / split / "images"
185
+ ppe_labels = ppe_dir / split / "labels"
186
+
187
+ if ppe_images.exists():
188
+ for img_file in sorted(ppe_images.iterdir()):
189
+ if img_file.suffix.lower() not in [".jpg", ".jpeg", ".png"]:
190
+ continue
191
+ shutil.copy2(img_file, out_images / f"ppe_{img_file.name}")
192
+
193
+ label_file = ppe_labels / f"{img_file.stem}.txt"
194
+ if label_file.exists():
195
+ with open(label_file) as f:
196
+ lines = f.readlines()
197
+ remapped = []
198
+ for line in lines:
199
+ parts = line.strip().split()
200
+ if len(parts) < 5:
201
+ continue
202
+ src_cls = int(parts[0])
203
+ if src_cls in ppe_class_map:
204
+ unified_cls = ppe_class_map[src_cls]
205
+ remapped.append(f"{unified_cls} {' '.join(parts[1:])}\n")
206
+
207
+ out_label = out_labels / f"ppe_{img_file.stem}.txt"
208
+ with open(out_label, "w") as f:
209
+ f.writelines(remapped)
210
+
211
+ k_images = keremberke_dir / split / "images"
212
+ k_labels = keremberke_dir / split / "labels"
213
+
214
+ if k_images.exists():
215
+ for img_file in sorted(k_images.iterdir()):
216
+ shutil.copy2(img_file, out_images / img_file.name)
217
+ for label_file in sorted(k_labels.iterdir()):
218
+ shutil.copy2(label_file, out_labels / label_file.name)
219
+
220
+ data_yaml = {
221
+ "path": str(output_dir.absolute()),
222
+ "train": "train/images",
223
+ "val": "valid/images",
224
+ "test": "test/images",
225
+ "names": {i: name for i, name in enumerate(UNIFIED_CLASSES)},
226
+ "nc": len(UNIFIED_CLASSES),
227
+ }
228
+
229
+ with open(output_dir / "data.yaml", "w") as f:
230
+ yaml.dump(data_yaml, f, default_flow_style=False)
231
+
232
+ print(f" Merged dataset at {output_dir}")
233
+ for split in ["train", "valid", "test"]:
234
+ img_count = len(list((output_dir / split / "images").glob("*")))
235
+ print(f" {split}: {img_count} images")
236
+
237
+
238
+ def train_model(data_yaml_path: Path):
239
+ print("[5/5] Training YOLOv8s...")
240
+ from ultralytics import YOLO
241
+
242
+ model = YOLO("yolov8s.pt")
243
+
244
+ results = model.train(
245
+ data=str(data_yaml_path),
246
+ epochs=EPOCHS,
247
+ imgsz=IMG_SIZE,
248
+ batch=BATCH,
249
+ device=DEVICE,
250
+ patience=30,
251
+ project="/app/runs",
252
+ name="ppe_improved",
253
+ exist_ok=True,
254
+ pretrained=True,
255
+ optimizer="SGD",
256
+ lr0=0.01,
257
+ lrf=0.01,
258
+ momentum=0.9,
259
+ weight_decay=0.0005,
260
+ augment=True,
261
+ mosaic=1.0,
262
+ hsv_h=0.015,
263
+ hsv_s=0.7,
264
+ hsv_v=0.4,
265
+ degrees=5.0,
266
+ translate=0.1,
267
+ scale=0.5,
268
+ shear=2.0,
269
+ perspective=0.0,
270
+ flipud=0.0,
271
+ fliplr=0.5,
272
+ )
273
+
274
+ print(" Training complete!")
275
+ print(f" Best model: {results.best}")
276
+ return results
277
+
278
+
279
+ def push_to_hub(best_model_path: Path):
280
+ print("Pushing model to HuggingFace Hub...")
281
+ api = HfApi()
282
+ repo_id = f"{HF_USERNAME}/{MODEL_ID}"
283
+
284
+ try:
285
+ api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
286
+ except Exception as e:
287
+ print(f" Repo creation info: {e}")
288
+
289
+ api.upload_file(
290
+ path_or_fileobj=str(best_model_path),
291
+ path_in_repo="best.pt",
292
+ repo_id=repo_id,
293
+ repo_type="model",
294
+ )
295
+
296
+ readme = f"""---
297
+ license: cc-by-4.0
298
+ library_name: ultralytics
299
+ tags:
300
+ - object-detection
301
+ - ppe
302
+ - construction-safety
303
+ - yolov8
304
+ ---
305
+
306
+ # {MODEL_ID}
307
+
308
+ Improved PPE Compliance Detection Model for Construction Sites (v2)
309
+
310
+ ## Description
311
+ This is an improved YOLOv8s model trained on a combined dataset of:
312
+ - **51ddhesh/PPE_Detection** (~10K images, 6 PPE classes)
313
+ - **keremberke/construction-safety-object-detection** (398 images, violation classes)
314
+
315
+ ## Classes ({len(UNIFIED_CLASSES)})
316
+ {chr(10).join(f"- {i}: {name}" for i, name in enumerate(UNIFIED_CLASSES))}
317
+
318
+ ## Usage
319
+ ```python
320
+ from ultralytics import YOLO
321
+ model = YOLO("hf://{repo_id}/best.pt")
322
+ results = model.predict("image.jpg")
323
+ ```
324
+
325
+ ## Training Details
326
+ - Base Model: YOLOv8s
327
+ - Epochs: {EPOCHS}
328
+ - Image Size: {IMG_SIZE}x{IMG_SIZE}
329
+ - Batch Size: {BATCH}
330
+ - Augmentations: Mosaic, HSV, scale, shear, flip
331
+
332
+ ## Compliance Detection
333
+ The model detects both PPE presence AND absence:
334
+ - `no_helmet`, `no_mask`, `no_vest` = violation classes
335
+ - `helmet`, `mask`, `vest` = compliance classes
336
+ """
337
+
338
+ api.upload_file(
339
+ path_or_fileobj=readme.encode(),
340
+ path_in_repo="README.md",
341
+ repo_id=repo_id,
342
+ repo_type="model",
343
+ )
344
+
345
+ print(f" Model pushed to https://huggingface.co/{repo_id}")
346
+
347
+
348
+ def main():
349
+ print("=" * 60)
350
+ print("IMPROVED PPE DETECTION TRAINING")
351
+ print("=" * 60)
352
+
353
+ ppe_dir = download_ppe_dataset()
354
+ keremberke_ds = load_keremberke_dataset()
355
+ keremberke_yolo_dir = Path("/app/keremberke_yolo")
356
+ convert_keremberke_to_yolo(keremberke_ds, keremberke_yolo_dir)
357
+ DATASET_DIR.mkdir(parents=True, exist_ok=True)
358
+ merge_datasets(ppe_dir, keremberke_yolo_dir, DATASET_DIR)
359
+ data_yaml = DATASET_DIR / "data.yaml"
360
+ results = train_model(data_yaml)
361
+
362
+ best_model = Path("/app/runs/ppe_improved/weights/best.pt")
363
+ if best_model.exists():
364
+ push_to_hub(best_model)
365
+ else:
366
+ print(f" WARNING: Best model not found at {best_model}")
367
+ for pt_file in Path("/app/runs").rglob("best.pt"):
368
+ print(f" Found: {pt_file}")
369
+ push_to_hub(pt_file)
370
+ break
371
+
372
+ print("=" * 60)
373
+ print("DONE!")
374
+ print("=" * 60)
375
+
376
+
377
+ if __name__ == "__main__":
378
+ main()