baskarmother committed on
Commit
d670cbb
·
verified ·
1 Parent(s): 78785c8

Upload train_ppe_fixed.py

Browse files
Files changed (1) hide show
  1. train_ppe_fixed.py +368 -0
train_ppe_fixed.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ PPE Compliance Detection Training - FIXED VERSION
4
+ - Downloads keremberke dataset as ZIP files (script-based datasets no longer supported)
5
+ - Uses model.train() return value correctly (no results.best)
6
+ - Pushes best.pt to HuggingFace Hub after training
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import zipfile
12
+ import shutil
13
+ import json
14
+ from pathlib import Path
15
+ from huggingface_hub import hf_hub_download, HfApi
16
+ from PIL import Image
17
+ import yaml
18
+
19
+ HF_USERNAME = "baskarmother"
20
+ MODEL_ID = "yolov8s-ppe-construction-v2"
21
+ DATASET_DIR = Path("/app/combined_ppe_dataset")
22
+ EPOCHS = 150
23
+ IMG_SIZE = 640
24
+ BATCH = 16
25
+ DEVICE = "0"
26
+
27
+ UNIFIED_CLASSES = [
28
+ "person", "helmet", "vest", "mask", "gloves",
29
+ "safety_shoe", "goggles", "no_helmet", "no_mask",
30
+ "no_vest", "head", "barricade", "dumpster",
31
+ "excavators", "safety_net", "dump_truck", "truck", "wheel_loader",
32
+ ]
33
+
34
+
35
def download_ppe_dataset():
    """Fetch the 51ddhesh/PPE_Detection archive from the Hub and unpack it.

    Returns:
        Path to the directory the ZIP was extracted into.
    """
    print("[1/5] Downloading 51ddhesh/PPE_Detection...")
    zip_path = hf_hub_download(
        repo_id="51ddhesh/PPE_Detection",
        filename="PPE.zip",
        repo_type="dataset",
        cache_dir="/app/hf_cache",
        local_dir="/app/downloads",
    )
    target = Path("/app/downloads/ppe_dataset")
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target)
    print(f" Extracted to {target}")
    return target
50
+
51
+
52
def download_keremberke_dataset():
    """Download and unpack the keremberke construction-safety split ZIPs.

    A failure for any individual split is logged as a warning and skipped,
    so a partially available dataset still downloads.

    Returns:
        Path to the download root directory.
    """
    print("[2/5] Downloading keremberke/construction-safety-object-detection...")
    root = Path("/app/downloads/keremberke")
    root.mkdir(parents=True, exist_ok=True)

    for archive_name in ("data/train.zip", "data/valid.zip", "data/test.zip"):
        try:
            archive_path = hf_hub_download(
                repo_id="keremberke/construction-safety-object-detection",
                filename=archive_name,
                repo_type="dataset",
                cache_dir="/app/hf_cache",
                local_dir=str(root),
            )
            # "data/train.zip" -> "<root>/train"
            dest = root / archive_name.replace("data/", "").replace(".zip", "")
            dest.mkdir(parents=True, exist_ok=True)
            with zipfile.ZipFile(archive_path) as archive:
                archive.extractall(dest)
            print(f" Downloaded and extracted {archive_name}")
        except Exception as e:
            print(f" Warning: Could not download {archive_name}: {e}")

    return root
75
+
76
+
77
def convert_keremberke_to_yolo(raw_dir: Path, output_dir: Path) -> None:
    """Convert the COCO-format keremberke dataset under *raw_dir* to YOLO layout.

    For each split (train/valid/test) every COCO ``*.json`` file is read;
    referenced images are copied into ``output_dir/<split>/images`` with a
    ``keremberke_`` prefix and a normalized YOLO label file is written into
    ``output_dir/<split>/labels`` using the unified class indices.

    Args:
        raw_dir: Root containing per-split COCO data (train/valid/test).
        output_dir: Destination root for the YOLO-format dataset.
    """
    print("[3/5] Converting keremberke dataset to YOLO format...")

    # keremberke category name -> unified class index (see UNIFIED_CLASSES).
    # "mini-van" is folded into "person"-slot 0? NOTE(review): mini-van -> 0
    # maps a vehicle onto "person" — kept as-is from the original mapping,
    # confirm intent with the dataset author.
    class_map = {
        "person": 0, "hardhat": 1, "mask": 3,
        "no-hardhat": 7, "no-mask": 8, "no-safety vest": 9,
        "gloves": 4, "safety shoes": 5, "safety vest": 2,
        "barricade": 11, "dumpster": 12, "excavators": 13,
        "safety net": 14, "dump truck": 15,
        "mini-van": 0, "truck": 16, "wheel loader": 17,
    }

    for split in ["train", "valid", "test"]:
        images_dir = output_dir / split / "images"
        labels_dir = output_dir / split / "labels"
        images_dir.mkdir(parents=True, exist_ok=True)
        labels_dir.mkdir(parents=True, exist_ok=True)

        raw_split_dir = raw_dir / split
        if not raw_split_dir.exists():
            print(f" WARNING: {raw_split_dir} not found, skipping")
            continue

        json_files = list(raw_split_dir.rglob("*.json"))
        print(f" {split}: Found {len(json_files)} JSON files")

        if not json_files:
            # No annotations at all: copy images so the split is not lost.
            img_files = []
            for ext in ["*.jpg", "*.jpeg", "*.png"]:
                img_files.extend(raw_split_dir.rglob(ext))
            for img_path in img_files:
                shutil.copy2(img_path, images_dir / f"keremberke_{img_path.name}")
            print(f" {split}: Copied {len(img_files)} images (no labels)")
            continue

        for coco_file in json_files:
            with open(coco_file) as f:
                coco_data = json.load(f)

            # Index the COCO metadata: image id -> filename / (width, height).
            image_id_to_file = {}
            image_id_to_size = {}
            for img in coco_data.get("images", []):
                image_id_to_file[img["id"]] = img["file_name"]
                image_id_to_size[img["id"]] = (img.get("width", 640), img.get("height", 640))

            cat_id_to_name = {}
            for cat in coco_data.get("categories", []):
                cat_id_to_name[cat["id"]] = cat["name"]

            anns_by_img = {}
            for ann in coco_data.get("annotations", []):
                anns_by_img.setdefault(ann["image_id"], []).append(ann)

            # Index every image on disk by filename for O(1) lookup.
            all_images = {}
            for ext in ["*.jpg", "*.jpeg", "*.png"]:
                for p in raw_split_dir.rglob(ext):
                    all_images[p.name] = p

            processed = 0
            for img_id, filename in image_id_to_file.items():
                img_path = all_images.get(filename)
                if not img_path:
                    continue

                # BUG FIX: previous version used the constant name
                # "keremberke_(unknown)", so every image and label overwrote
                # the previous one; derive the name from the COCO file_name.
                out_name = f"keremberke_{filename}"
                shutil.copy2(img_path, images_dir / out_name)

                w, h = image_id_to_size.get(img_id, (640, 640))
                label_path = labels_dir / f"{out_name.rsplit('.', 1)[0]}.txt"

                with open(label_path, "w") as f:
                    for ann in anns_by_img.get(img_id, []):
                        cat_name = cat_id_to_name.get(ann["category_id"], "")
                        if cat_name not in class_map:
                            continue  # category not part of the unified schema
                        cls = class_map[cat_name]
                        # COCO bbox is [x, y, w, h] in pixels; YOLO wants
                        # normalized center/size clamped to [0, 1].
                        x, y, bw, bh = ann["bbox"]
                        xc = max(0, min(1, (x + bw / 2) / w))
                        yc = max(0, min(1, (y + bh / 2) / h))
                        nw = max(0, min(1, bw / w))
                        nh = max(0, min(1, bh / h))
                        f.write(f"{cls} {xc:.6f} {yc:.6f} {nw:.6f} {nh:.6f}\n")
                processed += 1

            print(f" {split}: Processed {processed} images from {coco_file.name}")

    print(f" Converted to {output_dir}")
168
+
169
+
170
def merge_datasets(ppe_extract_dir: Path, keremberke_dir: Path, output_dir: Path) -> None:
    """Merge the PPE and keremberke YOLO datasets into *output_dir*.

    PPE labels are remapped from their native class ids to the unified
    schema; keremberke files are copied verbatim (they already use unified
    ids after convert_keremberke_to_yolo). Finally a ``data.yaml`` usable
    by ultralytics is written into *output_dir*.

    Exits the process with status 1 if the extracted PPE dataset layout
    cannot be located.
    """
    print("[4/5] Merging datasets...")
    output_dir.mkdir(parents=True, exist_ok=True)

    # The ZIP may nest the dataset under PPE/ or ppe/, or be flat.
    ppe_dir = None
    for candidate in [ppe_extract_dir / "PPE", ppe_extract_dir / "ppe", ppe_extract_dir]:
        if (candidate / "train" / "images").exists():
            ppe_dir = candidate
            break

    if ppe_dir is None:
        print(" ERROR: Could not find PPE dataset structure")
        sys.exit(1)

    print(f" Found PPE dataset at: {ppe_dir}")

    # PPE-native class id -> unified class id (index into UNIFIED_CLASSES).
    ppe_class_map = {0: 2, 1: 5, 2: 3, 3: 1, 4: 6, 5: 4}

    for split in ["train", "valid", "test"]:
        out_images = output_dir / split / "images"
        out_labels = output_dir / split / "labels"
        out_images.mkdir(parents=True, exist_ok=True)
        out_labels.mkdir(parents=True, exist_ok=True)

        ppe_images = ppe_dir / split / "images"
        ppe_labels = ppe_dir / split / "labels"
        if ppe_images.exists():
            for img_file in sorted(ppe_images.iterdir()):
                if img_file.suffix.lower() not in [".jpg", ".jpeg", ".png"]:
                    continue
                shutil.copy2(img_file, out_images / f"ppe_{img_file.name}")
                label_file = ppe_labels / f"{img_file.stem}.txt"
                if label_file.exists():
                    with open(label_file) as f:
                        lines = f.readlines()
                    remapped = []
                    for line in lines:
                        parts = line.strip().split()
                        if len(parts) < 5:
                            continue  # malformed label line
                        src_cls = int(parts[0])
                        # Drop classes with no unified mapping.
                        if src_cls in ppe_class_map:
                            remapped.append(f"{ppe_class_map[src_cls]} {' '.join(parts[1:])}\n")
                    out_label = out_labels / f"ppe_{img_file.stem}.txt"
                    with open(out_label, "w") as f:
                        f.writelines(remapped)

        # keremberke files are already in unified YOLO format: plain copy.
        k_images = keremberke_dir / split / "images"
        k_labels = keremberke_dir / split / "labels"
        if k_images.exists():
            for img_file in sorted(k_images.iterdir()):
                shutil.copy2(img_file, out_images / img_file.name)
            # BUG FIX: guard against a missing labels directory (previously
            # iterdir() raised FileNotFoundError when only images existed).
            if k_labels.exists():
                for label_file in sorted(k_labels.iterdir()):
                    shutil.copy2(label_file, out_labels / label_file.name)

    data_yaml = {
        "path": str(output_dir.absolute()),
        "train": "train/images",
        "val": "valid/images",
        "test": "test/images",
        "names": {i: name for i, name in enumerate(UNIFIED_CLASSES)},
        "nc": len(UNIFIED_CLASSES),
    }
    with open(output_dir / "data.yaml", "w") as f:
        yaml.dump(data_yaml, f, default_flow_style=False)

    for split in ["train", "valid", "test"]:
        n = len(list((output_dir / split / "images").glob("*")))
        print(f" {split}: {n} images")
239
+
240
+
241
def train_model(data_yaml_path: Path):
    """Fine-tune YOLOv8s on the merged dataset and return the best-weights path.

    Note: ultralytics' ``model.train()`` returns metrics, NOT a model object;
    the best checkpoint is auto-saved to /app/runs/ppe_improved/weights/best.pt.
    """
    print("[5/5] Training YOLOv8s...")
    from ultralytics import YOLO

    model = YOLO("yolov8s.pt")

    train_kwargs = dict(
        data=str(data_yaml_path),
        epochs=EPOCHS,
        imgsz=IMG_SIZE,
        batch=BATCH,
        device=DEVICE,
        patience=30,
        project="/app/runs",
        name="ppe_improved",
        exist_ok=True,
        pretrained=True,
        # SGD with the standard YOLO learning-rate schedule.
        optimizer="SGD",
        lr0=0.01,
        lrf=0.01,
        momentum=0.9,
        weight_decay=0.0005,
        # Augmentation settings.
        augment=True,
        mosaic=1.0,
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=5.0,
        translate=0.1,
        scale=0.5,
        shear=2.0,
        perspective=0.0,
        flipud=0.0,
        fliplr=0.5,
    )
    model.train(**train_kwargs)

    print(" Training complete!")
    best = Path("/app/runs/ppe_improved/weights/best.pt")
    print(f" Best model saved at: {best} (exists={best.exists()})")
    return best
283
+
284
+
285
def push_to_hub(best_model_path: Path):
    """Upload the trained weights and a generated model card to the HF Hub."""
    print("Pushing model to HuggingFace Hub...")
    repo_id = f"{HF_USERNAME}/{MODEL_ID}"
    api = HfApi()

    # Best-effort repo creation; an already-exists error is only informative.
    try:
        api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    except Exception as e:
        print(f" Repo info: {e}")

    # Weights first, then the README model card.
    api.upload_file(
        path_or_fileobj=str(best_model_path),
        path_in_repo="best.pt",
        repo_id=repo_id,
        repo_type="model",
    )

    readme = f"""---
license: cc-by-4.0
library_name: ultralytics
tags:
- object-detection
- ppe
- construction-safety
- yolov8
---

# {MODEL_ID}

Improved PPE Compliance Detection Model for Construction Sites (v2)

## Classes ({len(UNIFIED_CLASSES)})
{chr(10).join(f"- {i}: {name}" for i, name in enumerate(UNIFIED_CLASSES))}

## Usage
```python
from ultralytics import YOLO
model = YOLO("hf://{repo_id}/best.pt")
results = model.predict("image.jpg")
```

## Training Details
- Base Model: YOLOv8s
- Epochs: {EPOCHS}
- Image Size: {IMG_SIZE}x{IMG_SIZE}
- Batch Size: {BATCH}
"""
    api.upload_file(
        path_or_fileobj=readme.encode(),
        path_in_repo="README.md",
        repo_id=repo_id,
        repo_type="model",
    )
    print(f" Model pushed to https://huggingface.co/{repo_id}")
339
+
340
+
341
def main():
    """Run the full pipeline: download, convert, merge, train, upload."""
    banner = "=" * 60
    print(banner)
    print("IMPROVED PPE DETECTION TRAINING (FIXED)")
    print(banner)

    ppe_dir = download_ppe_dataset()
    raw_keremberke = download_keremberke_dataset()
    keremberke_yolo = Path("/app/keremberke_yolo")
    convert_keremberke_to_yolo(raw_keremberke, keremberke_yolo)

    DATASET_DIR.mkdir(parents=True, exist_ok=True)
    merge_datasets(ppe_dir, keremberke_yolo, DATASET_DIR)

    best_model = train_model(DATASET_DIR / "data.yaml")
    if best_model.exists():
        push_to_hub(best_model)
    else:
        # Fall back to the first best.pt found anywhere under /app/runs.
        print(f" WARNING: Best model not found at {best_model}")
        for candidate in Path("/app/runs").rglob("best.pt"):
            push_to_hub(candidate)
            break

    print(banner)
    print("DONE!")
    print(banner)
365
+
366
+
367
# Script entry point: run the full download/convert/train/upload pipeline.
if __name__ == "__main__":
    main()