HZSDU committed on
Commit
d9b768c
·
verified ·
1 Parent(s): ac03b97

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. yolov8_model/ultralytics/data/__init__.py +15 -0
  2. yolov8_model/ultralytics/data/__pycache__/__init__.cpython-310.pyc +0 -0
  3. yolov8_model/ultralytics/data/__pycache__/augment.cpython-310.pyc +0 -0
  4. yolov8_model/ultralytics/data/__pycache__/base.cpython-310.pyc +0 -0
  5. yolov8_model/ultralytics/data/__pycache__/build.cpython-310.pyc +0 -0
  6. yolov8_model/ultralytics/data/__pycache__/converter.cpython-310.pyc +0 -0
  7. yolov8_model/ultralytics/data/__pycache__/dataset.cpython-310.pyc +0 -0
  8. yolov8_model/ultralytics/data/__pycache__/loaders.cpython-310.pyc +0 -0
  9. yolov8_model/ultralytics/data/__pycache__/utils.cpython-310.pyc +0 -0
  10. yolov8_model/ultralytics/data/dataset.py +375 -0
  11. yolov8_model/ultralytics/data/explorer/__init__.py +5 -0
  12. yolov8_model/ultralytics/data/explorer/__pycache__/__init__.cpython-310.pyc +0 -0
  13. yolov8_model/ultralytics/data/explorer/__pycache__/explorer.cpython-310.pyc +0 -0
  14. yolov8_model/ultralytics/data/explorer/__pycache__/utils.cpython-310.pyc +0 -0
  15. yolov8_model/ultralytics/data/explorer/explorer.py +471 -0
  16. yolov8_model/ultralytics/data/explorer/gui/__init__.py +1 -0
  17. yolov8_model/ultralytics/data/explorer/gui/dash.py +268 -0
  18. yolov8_model/ultralytics/data/explorer/utils.py +166 -0
  19. yolov8_model/ultralytics/data/loaders.py +533 -0
  20. yolov8_model/ultralytics/data/scripts/download_weights.sh +18 -0
  21. yolov8_model/ultralytics/data/scripts/get_coco.sh +60 -0
  22. yolov8_model/ultralytics/data/scripts/get_coco128.sh +17 -0
  23. yolov8_model/ultralytics/data/scripts/get_imagenet.sh +51 -0
  24. yolov8_model/ultralytics/data/split_dota.py +288 -0
  25. yolov8_model/ultralytics/data/utils.py +647 -0
  26. yolov8_model/ultralytics/engine/__init__.py +1 -0
  27. yolov8_model/ultralytics/engine/__pycache__/__init__.cpython-310.pyc +0 -0
  28. yolov8_model/ultralytics/engine/__pycache__/exporter.cpython-310.pyc +0 -0
  29. yolov8_model/ultralytics/engine/__pycache__/model.cpython-310.pyc +0 -0
  30. yolov8_model/ultralytics/engine/__pycache__/predictor.cpython-310.pyc +0 -0
  31. yolov8_model/ultralytics/engine/__pycache__/results.cpython-310.pyc +0 -0
  32. yolov8_model/ultralytics/engine/__pycache__/trainer.cpython-310.pyc +0 -0
  33. yolov8_model/ultralytics/engine/__pycache__/validator.cpython-310.pyc +0 -0
  34. yolov8_model/ultralytics/engine/exporter.py +1099 -0
  35. yolov8_model/ultralytics/engine/model.py +772 -0
  36. yolov8_model/ultralytics/engine/predictor.py +407 -0
  37. yolov8_model/ultralytics/engine/results.py +680 -0
  38. yolov8_model/ultralytics/engine/trainer.py +755 -0
  39. yolov8_model/ultralytics/engine/tuner.py +240 -0
  40. yolov8_model/ultralytics/engine/validator.py +336 -0
  41. yolov8_model/ultralytics/hub/__init__.py +128 -0
  42. yolov8_model/ultralytics/hub/__pycache__/__init__.cpython-310.pyc +0 -0
  43. yolov8_model/ultralytics/hub/__pycache__/auth.cpython-310.pyc +0 -0
  44. yolov8_model/ultralytics/hub/__pycache__/utils.cpython-310.pyc +0 -0
  45. yolov8_model/ultralytics/hub/auth.py +136 -0
  46. yolov8_model/ultralytics/hub/session.py +348 -0
  47. yolov8_model/ultralytics/hub/utils.py +247 -0
  48. yolov8_model/ultralytics/models/__init__.py +7 -0
  49. yolov8_model/ultralytics/models/fastsam/__init__.py +8 -0
  50. yolov8_model/ultralytics/models/fastsam/__pycache__/__init__.cpython-310.pyc +0 -0
yolov8_model/ultralytics/data/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from .base import BaseDataset
4
+ from .build import build_dataloader, build_yolo_dataset, load_inference_source
5
+ from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
6
+
7
+ __all__ = (
8
+ "BaseDataset",
9
+ "ClassificationDataset",
10
+ "SemanticDataset",
11
+ "YOLODataset",
12
+ "build_yolo_dataset",
13
+ "build_dataloader",
14
+ "load_inference_source",
15
+ )
yolov8_model/ultralytics/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (470 Bytes). View file
 
yolov8_model/ultralytics/data/__pycache__/augment.cpython-310.pyc ADDED
Binary file (44.5 kB). View file
 
yolov8_model/ultralytics/data/__pycache__/base.cpython-310.pyc ADDED
Binary file (11.7 kB). View file
 
yolov8_model/ultralytics/data/__pycache__/build.cpython-310.pyc ADDED
Binary file (6.24 kB). View file
 
yolov8_model/ultralytics/data/__pycache__/converter.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
yolov8_model/ultralytics/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
yolov8_model/ultralytics/data/__pycache__/loaders.cpython-310.pyc ADDED
Binary file (20.4 kB). View file
 
yolov8_model/ultralytics/data/__pycache__/utils.cpython-310.pyc ADDED
Binary file (26.7 kB). View file
 
yolov8_model/ultralytics/data/dataset.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ import contextlib
3
+ from itertools import repeat
4
+ from multiprocessing.pool import ThreadPool
5
+ from pathlib import Path
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ from PIL import Image
12
+
13
+ from yolov8_model.ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
14
+ from yolov8_model.ultralytics.utils.ops import resample_segments
15
+ from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
16
+ from .base import BaseDataset
17
+ from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
18
+
19
+ # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
20
+ DATASET_CACHE_VERSION = "1.0.3"
21
+
22
+
23
class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes the YOLODataset with optional configurations for segments and keypoints."""
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path("./labels.cache")):
        """
        Cache dataset labels, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file. Default is Path('./labels.cache').

        Returns:
            (dict): labels.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        # Verify image/label pairs in parallel; each worker returns parsed arrays plus per-file counters.
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format="xywh",
                        )
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x)
        return x

    def get_labels(self):
        """Returns dictionary of labels for YOLO training."""
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in (-1, 0):
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache
        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
            # BUGFIX: return early — unpacking the per-label stats generator below would raise
            # "ValueError: not enough values to unpack" when the dataset contains zero labels,
            # because zip(*lengths) yields nothing for an empty labels list.
            return labels
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp=None):
        """Builds and appends transforms to the list."""
        if self.augment:
            # Mosaic/mixup are incompatible with rectangular training batches.
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
            )
        )
        return transforms

    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Custom your label format here.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # list[np.array(1000, 2)] * num_samples
            # (N, 1000, 2)
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches."""
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)
            if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch
224
+
225
+
226
# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLO Classification Dataset.

    Args:
        root (str): Dataset path.

    Attributes:
        cache_ram (bool): True if images should be cached in RAM, False otherwise.
        cache_disk (bool): True if images should be cached on disk, False otherwise.
        samples (list): List of samples containing file, index, npy, and im.
        torch_transforms (callable): torchvision transforms applied to the dataset.
        album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
    """

    def __init__(self, root, args, augment=False, cache=False, prefix=""):
        """
        Initialize YOLO object with root, image size, augmentations, and cache settings.

        Args:
            root (str): Dataset path.
            args (Namespace): Argument parser containing dataset related settings.
            augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
            cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
        """
        super().__init__(root=root)
        if augment and args.fraction < 1.0:  # reduce training fraction
            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
        self.cache_ram = cache is True or cache == "ram"
        self.cache_disk = cache == "disk"
        self.samples = self.verify_images()  # filter out bad images
        # Extend each (file, class-index) sample with a .npy cache path and a slot for the RAM-cached array.
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
        self.torch_transforms = (
            classify_augmentations(
                size=args.imgsz,
                scale=scale,
                hflip=args.fliplr,
                vflip=args.flipud,
                erasing=args.erasing,
                auto_augment=args.auto_augment,
                hsv_h=args.hsv_h,
                hsv_s=args.hsv_s,
                hsv_v=args.hsv_v,
            )
            if augment
            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
        )

    def __getitem__(self, i):
        """Returns subset of data and targets corresponding to given indices."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram and im is None:
            # First access under RAM caching: read once and store the BGR array in the sample slot.
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self):
        """Verify all images in dataset."""
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")  # *.cache file path

        # Cache hit path: the `return samples` below exits early; if any assertion fails (stale
        # version/hash) the exception is suppressed and control falls through to the rescan.
        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
            nf, nc, n, samples = cache.pop("results")  # found, corrupt, total, samples
            if LOCAL_RANK in (-1, 0):
                d = f"{desc} {nf} images, {nc} corrupt"
                TQDM(None, desc=d, total=n, initial=n)
                if cache["msgs"]:
                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
            return samples

        # Run scan if *.cache retrieval failed
        nf, nc, msgs, samples, x = 0, 0, [], [], {}
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
            pbar = TQDM(results, desc=desc, total=len(self.samples))
            for sample, nf_f, nc_f, msg in pbar:
                if nf_f:
                    samples.append(sample)
                if msg:
                    msgs.append(msg)
                nf += nf_f
                nc += nc_f
                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        x["hash"] = get_hash([x[0] for x in self.samples])
        x["results"] = nf, nc, len(samples), samples
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x)
        return samples
335
+
336
+
337
def load_dataset_cache_file(path):
    """
    Load an Ultralytics *.cache dictionary from path.

    Args:
        path (str | Path): Path to a *.cache file (a pickled dict written with np.save).

    Returns:
        (dict): The cached dataset dictionary.
    """
    import gc

    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    try:
        cache = np.load(str(path), allow_pickle=True).item()  # load dict
    finally:
        # BUGFIX: always re-enable garbage collection, even when np.load raises
        # (e.g. FileNotFoundError for a missing cache); previously a failed load
        # left GC disabled for the rest of the process.
        gc.enable()
    return cache
345
+
346
+
347
def save_dataset_cache_file(prefix, path, x):
    """Save an Ultralytics dataset *.cache dictionary x to path."""
    x["version"] = DATASET_CACHE_VERSION  # stamp the current cache version for later validation
    if not is_dir_writeable(path.parent):
        # Nothing we can do without write access; the next run will simply rescan.
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
        return
    if path.exists():
        path.unlink()  # remove *.cache file if exists
    # np.save always appends '.npy', so write then strip the suffix by renaming.
    np.save(str(path), x)  # save cache for next time
    path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
    LOGGER.info(f"{prefix}New cache created: {path}")
358
+
359
+
360
# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """
    Semantic Segmentation Dataset.

    This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
    from the BaseDataset class.

    Note:
        This class is currently a placeholder and needs to be populated with methods and attributes for supporting
        semantic segmentation tasks.
    """

    def __init__(self):
        """Initialize a SemanticDataset object."""
        # NOTE(review): BaseDataset appears to take required arguments elsewhere in this file
        # (YOLODataset forwards *args/**kwargs to it), so this no-arg call likely fails at
        # runtime — confirm before instantiating this placeholder.
        super().__init__()
yolov8_model/ultralytics/data/explorer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from .utils import plot_query_result
4
+
5
+ __all__ = ["plot_query_result"]
yolov8_model/ultralytics/data/explorer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (248 Bytes). View file
 
yolov8_model/ultralytics/data/explorer/__pycache__/explorer.cpython-310.pyc ADDED
Binary file (17 kB). View file
 
yolov8_model/ultralytics/data/explorer/__pycache__/utils.cpython-310.pyc ADDED
Binary file (7.39 kB). View file
 
yolov8_model/ultralytics/data/explorer/explorer.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from io import BytesIO
4
+ from pathlib import Path
5
+ from typing import Any, List, Tuple, Union
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from matplotlib import pyplot as plt
12
+ from pandas import DataFrame
13
+ from tqdm import tqdm
14
+
15
+ from yolov8_model.ultralytics.data.augment import Format
16
+ from yolov8_model.ultralytics.data.dataset import YOLODataset
17
+ from yolov8_model.ultralytics.data.utils import check_det_dataset
18
+ from yolov8_model.ultralytics.models.yolo.model import YOLO
19
+ from yolov8_model.ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR
20
+ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
21
+
22
+
23
class ExplorerDataset(YOLODataset):
    """YOLODataset variant used by the Explorer: images keep their native size (no resize transforms)."""

    def __init__(self, *args, data: dict = None, **kwargs) -> None:
        """Initialize the dataset, forwarding all arguments to YOLODataset."""
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
        """Loads 1 image from dataset index 'i' without any resize ops."""
        cached = self.ims[i]
        if cached is not None:  # already held in RAM — return it untouched
            return self.ims[i], self.im_hw0[i], self.im_hw[i]

        img_path, npy_path = self.im_files[i], self.npy_files[i]
        if npy_path.exists():  # prefer the pre-saved numpy array
            img = np.load(npy_path)
        else:
            img = cv2.imread(img_path)  # BGR
            if img is None:
                raise FileNotFoundError(f"Image Not Found {img_path}")
        original_hw = img.shape[:2]  # orig hw
        return img, original_hw, img.shape[:2]

    def build_transforms(self, hyp: IterableSimpleNamespace = None):
        """Creates transforms for dataset images without resizing."""
        formatter = Format(
            bbox_format="xyxy",
            normalize=False,
            return_mask=self.use_segments,
            return_keypoint=self.use_keypoints,
            batch_idx=True,
            mask_ratio=hyp.mask_ratio,
            mask_overlap=hyp.overlap_mask,
        )
        return formatter
53
+
54
+
55
+ class Explorer:
56
+ def __init__(
57
+ self,
58
+ data: Union[str, Path] = "coco128.yaml",
59
+ model: str = "yolov8n.pt",
60
+ uri: str = USER_CONFIG_DIR / "explorer",
61
+ ) -> None:
62
+ checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
63
+ import lancedb
64
+
65
+ self.connection = lancedb.connect(uri)
66
+ self.table_name = Path(data).name.lower() + "_" + model.lower()
67
+ self.sim_idx_base_name = (
68
+ f"{self.table_name}_sim_idx".lower()
69
+ ) # Use this name and append thres and top_k to reuse the table
70
+ self.model = YOLO(model)
71
+ self.data = data # None
72
+ self.choice_set = None
73
+
74
+ self.table = None
75
+ self.progress = 0
76
+
77
+ def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
78
+ """
79
+ Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
80
+ already exists. Pass force=True to overwrite the existing table.
81
+
82
+ Args:
83
+ force (bool): Whether to overwrite the existing table or not. Defaults to False.
84
+ split (str): Split of the dataset to use. Defaults to 'train'.
85
+
86
+ Example:
87
+ ```python
88
+ exp = Explorer()
89
+ exp.create_embeddings_table()
90
+ ```
91
+ """
92
+ if self.table is not None and not force:
93
+ LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
94
+ return
95
+ if self.table_name in self.connection.table_names() and not force:
96
+ LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
97
+ self.table = self.connection.open_table(self.table_name)
98
+ self.progress = 1
99
+ return
100
+ if self.data is None:
101
+ raise ValueError("Data must be provided to create embeddings table")
102
+
103
+ data_info = check_det_dataset(self.data)
104
+ if split not in data_info:
105
+ raise ValueError(
106
+ f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
107
+ )
108
+
109
+ choice_set = data_info[split]
110
+ choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
111
+ self.choice_set = choice_set
112
+ dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)
113
+
114
+ # Create the table schema
115
+ batch = dataset[0]
116
+ vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
117
+ table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
118
+ table.add(
119
+ self._yield_batches(
120
+ dataset,
121
+ data_info,
122
+ self.model,
123
+ exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
124
+ )
125
+ )
126
+
127
+ self.table = table
128
+
129
+ def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
130
+ """Generates batches of data for embedding, excluding specified keys."""
131
+ for i in tqdm(range(len(dataset))):
132
+ self.progress = float(i + 1) / len(dataset)
133
+ batch = dataset[i]
134
+ for k in exclude_keys:
135
+ batch.pop(k, None)
136
+ batch = sanitize_batch(batch, data_info)
137
+ batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
138
+ yield [batch]
139
+
140
+ def query(
141
+ self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
142
+ ) -> Any: # pyarrow.Table
143
+ """
144
+ Query the table for similar images. Accepts a single image or a list of images.
145
+
146
+ Args:
147
+ imgs (str or list): Path to the image or a list of paths to the images.
148
+ limit (int): Number of results to return.
149
+
150
+ Returns:
151
+ (pyarrow.Table): An arrow table containing the results. Supports converting to:
152
+ - pandas dataframe: `result.to_pandas()`
153
+ - dict of lists: `result.to_pydict()`
154
+
155
+ Example:
156
+ ```python
157
+ exp = Explorer()
158
+ exp.create_embeddings_table()
159
+ similar = exp.query(img='https://ultralytics.com/images/zidane.jpg')
160
+ ```
161
+ """
162
+ if self.table is None:
163
+ raise ValueError("Table is not created. Please create the table first.")
164
+ if isinstance(imgs, str):
165
+ imgs = [imgs]
166
+ assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
167
+ embeds = self.model.embed(imgs)
168
+ # Get avg if multiple images are passed (len > 1)
169
+ embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
170
+ return self.table.search(embeds).limit(limit).to_arrow()
171
+
172
+ def sql_query(
173
+ self, query: str, return_type: str = "pandas"
174
+ ) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
175
+ """
176
+ Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
177
+
178
+ Args:
179
+ query (str): SQL query to run.
180
+ return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
181
+
182
+ Returns:
183
+ (pyarrow.Table): An arrow table containing the results.
184
+
185
+ Example:
186
+ ```python
187
+ exp = Explorer()
188
+ exp.create_embeddings_table()
189
+ query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
190
+ result = exp.sql_query(query)
191
+ ```
192
+ """
193
+ assert return_type in {
194
+ "pandas",
195
+ "arrow",
196
+ }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
197
+ import duckdb
198
+
199
+ if self.table is None:
200
+ raise ValueError("Table is not created. Please create the table first.")
201
+
202
+ # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
203
+ table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
204
+ if not query.startswith("SELECT") and not query.startswith("WHERE"):
205
+ raise ValueError(
206
+ f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
207
+ )
208
+ if query.startswith("WHERE"):
209
+ query = f"SELECT * FROM 'table' {query}"
210
+ LOGGER.info(f"Running query: {query}")
211
+
212
+ rs = duckdb.sql(query)
213
+ if return_type == "arrow":
214
+ return rs.arrow()
215
+ elif return_type == "pandas":
216
+ return rs.df()
217
+
218
+ def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
219
+ """
220
+ Plot the results of a SQL-Like query on the table.
221
+ Args:
222
+ query (str): SQL query to run.
223
+ labels (bool): Whether to plot the labels or not.
224
+
225
+ Returns:
226
+ (PIL.Image): Image containing the plot.
227
+
228
+ Example:
229
+ ```python
230
+ exp = Explorer()
231
+ exp.create_embeddings_table()
232
+ query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
233
+ result = exp.plot_sql_query(query)
234
+ ```
235
+ """
236
+ result = self.sql_query(query, return_type="arrow")
237
+ if len(result) == 0:
238
+ LOGGER.info("No results found.")
239
+ return None
240
+ img = plot_query_result(result, plot_labels=labels)
241
+ return Image.fromarray(img)
242
+
243
+ def get_similar(
244
+ self,
245
+ img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
246
+ idx: Union[int, List[int]] = None,
247
+ limit: int = 25,
248
+ return_type: str = "pandas",
249
+ ) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
250
+ """
251
+ Query the table for similar images. Accepts a single image or a list of images.
252
+
253
+ Args:
254
+ img (str or list): Path to the image or a list of paths to the images.
255
+ idx (int or list): Index of the image in the table or a list of indexes.
256
+ limit (int): Number of results to return. Defaults to 25.
257
+ return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
258
+
259
+ Returns:
260
+ (pandas.DataFrame): A dataframe containing the results.
261
+
262
+ Example:
263
+ ```python
264
+ exp = Explorer()
265
+ exp.create_embeddings_table()
266
+ similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
267
+ ```
268
+ """
269
+ assert return_type in {
270
+ "pandas",
271
+ "arrow",
272
+ }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
273
+ img = self._check_imgs_or_idxs(img, idx)
274
+ similar = self.query(img, limit=limit)
275
+
276
+ if return_type == "arrow":
277
+ return similar
278
+ elif return_type == "pandas":
279
+ return similar.to_pandas()
280
+
281
+ def plot_similar(
282
+ self,
283
+ img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
284
+ idx: Union[int, List[int]] = None,
285
+ limit: int = 25,
286
+ labels: bool = True,
287
+ ) -> Image.Image:
288
+ """
289
+ Plot the similar images. Accepts images or indexes.
290
+
291
+ Args:
292
+ img (str or list): Path to the image or a list of paths to the images.
293
+ idx (int or list): Index of the image in the table or a list of indexes.
294
+ labels (bool): Whether to plot the labels or not.
295
+ limit (int): Number of results to return. Defaults to 25.
296
+
297
+ Returns:
298
+ (PIL.Image): Image containing the plot.
299
+
300
+ Example:
301
+ ```python
302
+ exp = Explorer()
303
+ exp.create_embeddings_table()
304
+ similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
305
+ ```
306
+ """
307
+ similar = self.get_similar(img, idx, limit, return_type="arrow")
308
+ if len(similar) == 0:
309
+ LOGGER.info("No results found.")
310
+ return None
311
+ img = plot_query_result(similar, plot_labels=labels)
312
+ return Image.fromarray(img)
313
+
314
+ def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
315
+ """
316
+ Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
317
+ are max_dist or closer to the image in the embedding space at a given index.
318
+
319
+ Args:
320
+ max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
321
+ top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
322
+ vector search. Defaults: None.
323
+ force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
324
+
325
+ Returns:
326
+ (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
327
+ include indices of similar images and their respective distances.
328
+
329
+ Example:
330
+ ```python
331
+ exp = Explorer()
332
+ exp.create_embeddings_table()
333
+ sim_idx = exp.similarity_index()
334
+ ```
335
+ """
336
+ if self.table is None:
337
+ raise ValueError("Table is not created. Please create the table first.")
338
+ sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
339
+ if sim_idx_table_name in self.connection.table_names() and not force:
340
+ LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
341
+ return self.connection.open_table(sim_idx_table_name).to_pandas()
342
+
343
+ if top_k and not (1.0 >= top_k >= 0.0):
344
+ raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
345
+ if max_dist < 0.0:
346
+ raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
347
+
348
+ top_k = int(top_k * len(self.table)) if top_k else len(self.table)
349
+ top_k = max(top_k, 1)
350
+ features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
351
+ im_files = features["im_file"]
352
+ embeddings = features["vector"]
353
+
354
+ sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
355
+
356
+ def _yield_sim_idx():
357
+ """Generates a dataframe with similarity indices and distances for images."""
358
+ for i in tqdm(range(len(embeddings))):
359
+ sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
360
+ yield [
361
+ {
362
+ "idx": i,
363
+ "im_file": im_files[i],
364
+ "count": len(sim_idx),
365
+ "sim_im_files": sim_idx["im_file"].tolist(),
366
+ }
367
+ ]
368
+
369
+ sim_table.add(_yield_sim_idx())
370
+ self.sim_index = sim_table
371
+ return sim_table.to_pandas()
372
+
373
+ def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
374
+ """
375
+ Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
376
+ max_dist or closer to the image in the embedding space at a given index.
377
+
378
+ Args:
379
+ max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
380
+ top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
381
+ running vector search. Defaults to 0.01.
382
+ force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
383
+
384
+ Returns:
385
+ (PIL.Image): Image containing the plot.
386
+
387
+ Example:
388
+ ```python
389
+ exp = Explorer()
390
+ exp.create_embeddings_table()
391
+
392
+ similarity_idx_plot = exp.plot_similarity_index()
393
+ similarity_idx_plot.show() # view image preview
394
+ similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
395
+ ```
396
+ """
397
+ sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
398
+ sim_count = sim_idx["count"].tolist()
399
+ sim_count = np.array(sim_count)
400
+
401
+ indices = np.arange(len(sim_count))
402
+
403
+ # Create the bar plot
404
+ plt.bar(indices, sim_count)
405
+
406
+ # Customize the plot (optional)
407
+ plt.xlabel("data idx")
408
+ plt.ylabel("Count")
409
+ plt.title("Similarity Count")
410
+ buffer = BytesIO()
411
+ plt.savefig(buffer, format="png")
412
+ buffer.seek(0)
413
+
414
+ # Use Pillow to open the image from the buffer
415
+ return Image.fromarray(np.array(Image.open(buffer)))
416
+
417
+ def _check_imgs_or_idxs(
418
+ self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
419
+ ) -> List[np.ndarray]:
420
+ if img is None and idx is None:
421
+ raise ValueError("Either img or idx must be provided.")
422
+ if img is not None and idx is not None:
423
+ raise ValueError("Only one of img or idx must be provided.")
424
+ if idx is not None:
425
+ idx = idx if isinstance(idx, list) else [idx]
426
+ img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
427
+
428
+ return img if isinstance(img, list) else [img]
429
+
430
+ def ask_ai(self, query):
431
+ """
432
+ Ask AI a question.
433
+
434
+ Args:
435
+ query (str): Question to ask.
436
+
437
+ Returns:
438
+ (pandas.DataFrame): A dataframe containing filtered results to the SQL query.
439
+
440
+ Example:
441
+ ```python
442
+ exp = Explorer()
443
+ exp.create_embeddings_table()
444
+ answer = exp.ask_ai('Show images with 1 person and 2 dogs')
445
+ ```
446
+ """
447
+ result = prompt_sql_query(query)
448
+ try:
449
+ df = self.sql_query(result)
450
+ except Exception as e:
451
+ LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
452
+ LOGGER.error(e)
453
+ return None
454
+ return df
455
+
456
+ def visualize(self, result):
457
+ """
458
+ Visualize the results of a query. TODO.
459
+
460
+ Args:
461
+ result (pyarrow.Table): Table containing the results of a query.
462
+ """
463
+ pass
464
+
465
+ def generate_report(self, result):
466
+ """
467
+ Generate a report of the dataset.
468
+
469
+ TODO
470
+ """
471
+ pass
yolov8_model/ultralytics/data/explorer/gui/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
yolov8_model/ultralytics/data/explorer/gui/dash.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import time
4
+ from threading import Thread
5
+
6
+ import pandas as pd
7
+
8
+ from ultralytics import Explorer
9
+ from ultralytics.utils import ROOT, SETTINGS
10
+ from ultralytics.utils.checks import check_requirements
11
+
12
+ check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3"))
13
+
14
+ import streamlit as st
15
+ from streamlit_select import image_select
16
+
17
+
18
+ def _get_explorer():
19
+ """Initializes and returns an instance of the Explorer class."""
20
+ exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
21
+ thread = Thread(
22
+ target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
23
+ )
24
+ thread.start()
25
+ progress_bar = st.progress(0, text="Creating embeddings table...")
26
+ while exp.progress < 1:
27
+ time.sleep(0.1)
28
+ progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
29
+ thread.join()
30
+ st.session_state["explorer"] = exp
31
+ progress_bar.empty()
32
+
33
+
34
+ def init_explorer_form():
35
+ """Initializes an Explorer instance and creates embeddings table with progress tracking."""
36
+ datasets = ROOT / "cfg" / "datasets"
37
+ ds = [d.name for d in datasets.glob("*.yaml")]
38
+ models = [
39
+ "yolov8n.pt",
40
+ "yolov8s.pt",
41
+ "yolov8m.pt",
42
+ "yolov8l.pt",
43
+ "yolov8x.pt",
44
+ "yolov8n-seg.pt",
45
+ "yolov8s-seg.pt",
46
+ "yolov8m-seg.pt",
47
+ "yolov8l-seg.pt",
48
+ "yolov8x-seg.pt",
49
+ "yolov8n-pose.pt",
50
+ "yolov8s-pose.pt",
51
+ "yolov8m-pose.pt",
52
+ "yolov8l-pose.pt",
53
+ "yolov8x-pose.pt",
54
+ ]
55
+ with st.form(key="explorer_init_form"):
56
+ col1, col2 = st.columns(2)
57
+ with col1:
58
+ st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
59
+ with col2:
60
+ st.selectbox("Select model", models, key="model")
61
+ st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
62
+
63
+ st.form_submit_button("Explore", on_click=_get_explorer)
64
+
65
+
66
+ def query_form():
67
+ """Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
68
+ with st.form("query_form"):
69
+ col1, col2 = st.columns([0.8, 0.2])
70
+ with col1:
71
+ st.text_input(
72
+ "Query",
73
+ "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
74
+ label_visibility="collapsed",
75
+ key="query",
76
+ )
77
+ with col2:
78
+ st.form_submit_button("Query", on_click=run_sql_query)
79
+
80
+
81
+ def ai_query_form():
82
+ """Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
83
+ with st.form("ai_query_form"):
84
+ col1, col2 = st.columns([0.8, 0.2])
85
+ with col1:
86
+ st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
87
+ with col2:
88
+ st.form_submit_button("Ask AI", on_click=run_ai_query)
89
+
90
+
91
+ def find_similar_imgs(imgs):
92
+ """Initializes a Streamlit form for AI-based image querying with custom input."""
93
+ exp = st.session_state["explorer"]
94
+ similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
95
+ paths = similar.to_pydict()["im_file"]
96
+ st.session_state["imgs"] = paths
97
+ st.session_state["res"] = similar
98
+
99
+
100
+ def similarity_form(selected_imgs):
101
+ """Initializes a form for AI-based image querying with custom input in Streamlit."""
102
+ st.write("Similarity Search")
103
+ with st.form("similarity_form"):
104
+ subcol1, subcol2 = st.columns([1, 1])
105
+ with subcol1:
106
+ st.number_input(
107
+ "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
108
+ )
109
+
110
+ with subcol2:
111
+ disabled = not len(selected_imgs)
112
+ st.write("Selected: ", len(selected_imgs))
113
+ st.form_submit_button(
114
+ "Search",
115
+ disabled=disabled,
116
+ on_click=find_similar_imgs,
117
+ args=(selected_imgs,),
118
+ )
119
+ if disabled:
120
+ st.error("Select at least one image to search.")
121
+
122
+
123
+ # def persist_reset_form():
124
+ # with st.form("persist_reset"):
125
+ # col1, col2 = st.columns([1, 1])
126
+ # with col1:
127
+ # st.form_submit_button("Reset", on_click=reset)
128
+ #
129
+ # with col2:
130
+ # st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True))
131
+
132
+
133
+ def run_sql_query():
134
+ """Executes an SQL query and returns the results."""
135
+ st.session_state["error"] = None
136
+ query = st.session_state.get("query")
137
+ if query.rstrip().lstrip():
138
+ exp = st.session_state["explorer"]
139
+ res = exp.sql_query(query, return_type="arrow")
140
+ st.session_state["imgs"] = res.to_pydict()["im_file"]
141
+ st.session_state["res"] = res
142
+
143
+
144
+ def run_ai_query():
145
+ """Execute SQL query and update session state with query results."""
146
+ if not SETTINGS["openai_api_key"]:
147
+ st.session_state[
148
+ "error"
149
+ ] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
150
+ return
151
+ st.session_state["error"] = None
152
+ query = st.session_state.get("ai_query")
153
+ if query.rstrip().lstrip():
154
+ exp = st.session_state["explorer"]
155
+ res = exp.ask_ai(query)
156
+ if not isinstance(res, pd.DataFrame) or res.empty:
157
+ st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
158
+ return
159
+ st.session_state["imgs"] = res["im_file"].to_list()
160
+ st.session_state["res"] = res
161
+
162
+
163
+ def reset_explorer():
164
+ """Resets the explorer to its initial state by clearing session variables."""
165
+ st.session_state["explorer"] = None
166
+ st.session_state["imgs"] = None
167
+ st.session_state["error"] = None
168
+
169
+
170
+ def utralytics_explorer_docs_callback():
171
+ """Resets the explorer to its initial state by clearing session variables."""
172
+ with st.container(border=True):
173
+ st.image(
174
+ "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
175
+ width=100,
176
+ )
177
+ st.markdown(
178
+ "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
179
+ unsafe_allow_html=True,
180
+ help=None,
181
+ )
182
+ st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
183
+
184
+
185
+ def layout():
186
+ """Resets explorer session variables and provides documentation with a link to API docs."""
187
+ st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
188
+ st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
189
+
190
+ if st.session_state.get("explorer") is None:
191
+ init_explorer_form()
192
+ return
193
+
194
+ st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
195
+ exp = st.session_state.get("explorer")
196
+ col1, col2 = st.columns([0.75, 0.25], gap="small")
197
+ imgs = []
198
+ if st.session_state.get("error"):
199
+ st.error(st.session_state["error"])
200
+ else:
201
+ if st.session_state.get("imgs"):
202
+ imgs = st.session_state.get("imgs")
203
+ else:
204
+ imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
205
+ st.session_state["res"] = exp.table.to_arrow()
206
+ total_imgs, selected_imgs = len(imgs), []
207
+ with col1:
208
+ subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
209
+ with subcol1:
210
+ st.write("Max Images Displayed:")
211
+ with subcol2:
212
+ num = st.number_input(
213
+ "Max Images Displayed",
214
+ min_value=0,
215
+ max_value=total_imgs,
216
+ value=min(500, total_imgs),
217
+ key="num_imgs_displayed",
218
+ label_visibility="collapsed",
219
+ )
220
+ with subcol3:
221
+ st.write("Start Index:")
222
+ with subcol4:
223
+ start_idx = st.number_input(
224
+ "Start Index",
225
+ min_value=0,
226
+ max_value=total_imgs,
227
+ value=0,
228
+ key="start_index",
229
+ label_visibility="collapsed",
230
+ )
231
+ with subcol5:
232
+ reset = st.button("Reset", use_container_width=False, key="reset")
233
+ if reset:
234
+ st.session_state["imgs"] = None
235
+ st.experimental_rerun()
236
+
237
+ query_form()
238
+ ai_query_form()
239
+ if total_imgs:
240
+ labels, boxes, masks, kpts, classes = None, None, None, None, None
241
+ task = exp.model.task
242
+ if st.session_state.get("display_labels"):
243
+ labels = st.session_state.get("res").to_pydict()["labels"][start_idx : start_idx + num]
244
+ boxes = st.session_state.get("res").to_pydict()["bboxes"][start_idx : start_idx + num]
245
+ masks = st.session_state.get("res").to_pydict()["masks"][start_idx : start_idx + num]
246
+ kpts = st.session_state.get("res").to_pydict()["keypoints"][start_idx : start_idx + num]
247
+ classes = st.session_state.get("res").to_pydict()["cls"][start_idx : start_idx + num]
248
+ imgs_displayed = imgs[start_idx : start_idx + num]
249
+ selected_imgs = image_select(
250
+ f"Total samples: {total_imgs}",
251
+ images=imgs_displayed,
252
+ use_container_width=False,
253
+ # indices=[i for i in range(num)] if select_all else None,
254
+ labels=labels,
255
+ classes=classes,
256
+ bboxes=boxes,
257
+ masks=masks if task == "segment" else None,
258
+ kpts=kpts if task == "pose" else None,
259
+ )
260
+
261
+ with col2:
262
+ similarity_form(selected_imgs)
263
+ display_labels = st.checkbox("Labels", value=False, key="display_labels")
264
+ utralytics_explorer_docs_callback()
265
+
266
+
267
+ if __name__ == "__main__":
268
+ layout()
yolov8_model/ultralytics/data/explorer/utils.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import getpass
4
+ from typing import List
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from yolov8_model.ultralytics.data.augment import LetterBox
11
+ from yolov8_model.ultralytics.utils import LOGGER as logger
12
+ from yolov8_model.ultralytics.utils import SETTINGS
13
+ from yolov8_model.ultralytics.utils.checks import check_requirements
14
+ from yolov8_model.ultralytics.utils.ops import xyxy2xywh
15
+ from yolov8_model.ultralytics.utils.plotting import plot_images
16
+
17
+
18
+ def get_table_schema(vector_size):
19
+ """Extracts and returns the schema of a database table."""
20
+ from lancedb.pydantic import LanceModel, Vector
21
+
22
+ class Schema(LanceModel):
23
+ im_file: str
24
+ labels: List[str]
25
+ cls: List[int]
26
+ bboxes: List[List[float]]
27
+ masks: List[List[List[int]]]
28
+ keypoints: List[List[List[float]]]
29
+ vector: Vector(vector_size)
30
+
31
+ return Schema
32
+
33
+
34
+ def get_sim_index_schema():
35
+ """Returns a LanceModel schema for a database table with specified vector size."""
36
+ from lancedb.pydantic import LanceModel
37
+
38
+ class Schema(LanceModel):
39
+ idx: int
40
+ im_file: str
41
+ count: int
42
+ sim_im_files: List[str]
43
+
44
+ return Schema
45
+
46
+
47
+ def sanitize_batch(batch, dataset_info):
48
+ """Sanitizes input batch for inference, ensuring correct format and dimensions."""
49
+ batch["cls"] = batch["cls"].flatten().int().tolist()
50
+ box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
51
+ batch["bboxes"] = [box for box, _ in box_cls_pair]
52
+ batch["cls"] = [cls for _, cls in box_cls_pair]
53
+ batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
54
+ batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
55
+ batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
56
+ return batch
57
+
58
+
59
+ def plot_query_result(similar_set, plot_labels=True):
60
+ """
61
+ Plot images from the similar set.
62
+
63
+ Args:
64
+ similar_set (list): Pyarrow or pandas object containing the similar data points
65
+ plot_labels (bool): Whether to plot labels or not
66
+ """
67
+ similar_set = (
68
+ similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
69
+ )
70
+ empty_masks = [[[]]]
71
+ empty_boxes = [[]]
72
+ images = similar_set.get("im_file", [])
73
+ bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
74
+ masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
75
+ kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
76
+ cls = similar_set.get("cls", [])
77
+
78
+ plot_size = 640
79
+ imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
80
+ for i, imf in enumerate(images):
81
+ im = cv2.imread(imf)
82
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
83
+ h, w = im.shape[:2]
84
+ r = min(plot_size / h, plot_size / w)
85
+ imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1))
86
+ if plot_labels:
87
+ if len(bboxes) > i and len(bboxes[i]) > 0:
88
+ box = np.array(bboxes[i], dtype=np.float32)
89
+ box[:, [0, 2]] *= r
90
+ box[:, [1, 3]] *= r
91
+ plot_boxes.append(box)
92
+ if len(masks) > i and len(masks[i]) > 0:
93
+ mask = np.array(masks[i], dtype=np.uint8)[0]
94
+ plot_masks.append(LetterBox(plot_size, center=False)(image=mask))
95
+ if len(kpts) > i and kpts[i] is not None:
96
+ kpt = np.array(kpts[i], dtype=np.float32)
97
+ kpt[:, :, :2] *= r
98
+ plot_kpts.append(kpt)
99
+ batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i)
100
+ imgs = np.stack(imgs, axis=0)
101
+ masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
102
+ kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
103
+ boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
104
+ batch_idx = np.concatenate(batch_idx, axis=0)
105
+ cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
106
+
107
+ return plot_images(
108
+ imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
109
+ )
110
+
111
+
112
+ def prompt_sql_query(query):
113
+ """Plots images with optional labels from a similar data set."""
114
+ check_requirements("openai>=1.6.1")
115
+ from openai import OpenAI
116
+
117
+ if not SETTINGS["openai_api_key"]:
118
+ logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
119
+ openai_api_key = getpass.getpass("OpenAI API key: ")
120
+ SETTINGS.update({"openai_api_key": openai_api_key})
121
+ openai = OpenAI(api_key=SETTINGS["openai_api_key"])
122
+
123
+ messages = [
124
+ {
125
+ "role": "system",
126
+ "content": """
127
+ You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
128
+ the following schema and a user request. You only need to output the format with fixed selection
129
+ statement that selects everything from "'table'", like `SELECT * from 'table'`
130
+
131
+ Schema:
132
+ im_file: string not null
133
+ labels: list<item: string> not null
134
+ child 0, item: string
135
+ cls: list<item: int64> not null
136
+ child 0, item: int64
137
+ bboxes: list<item: list<item: double>> not null
138
+ child 0, item: list<item: double>
139
+ child 0, item: double
140
+ masks: list<item: list<item: list<item: int64>>> not null
141
+ child 0, item: list<item: list<item: int64>>
142
+ child 0, item: list<item: int64>
143
+ child 0, item: int64
144
+ keypoints: list<item: list<item: list<item: double>>> not null
145
+ child 0, item: list<item: list<item: double>>
146
+ child 0, item: list<item: double>
147
+ child 0, item: double
148
+ vector: fixed_size_list<item: float>[256] not null
149
+ child 0, item: float
150
+
151
+ Some details about the schema:
152
+ - the "labels" column contains the string values like 'person' and 'dog' for the respective objects
153
+ in each image
154
+ - the "cls" column contains the integer values on these classes that map them the labels
155
+
156
+ Example of a correct query:
157
+ request - Get all data points that contain 2 or more people and at least one dog
158
+ correct query-
159
+ SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
160
+ """,
161
+ },
162
+ {"role": "user", "content": f"{query}"},
163
+ ]
164
+
165
+ response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
166
+ return response.choices[0].message.content
yolov8_model/ultralytics/data/loaders.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import glob
4
+ import math
5
+ import os
6
+ import time
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from threading import Thread
10
+ from urllib.parse import urlparse
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import requests
15
+ import torch
16
+ from PIL import Image
17
+
18
+ from yolov8_model.ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
19
+ from yolov8_model.ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
20
+ from yolov8_model.ultralytics.utils.checks import check_requirements
21
+
22
+
23
+ @dataclass
24
+ class SourceTypes:
25
+ """Class to represent various types of input sources for predictions."""
26
+
27
+ webcam: bool = False
28
+ screenshot: bool = False
29
+ from_img: bool = False
30
+ tensor: bool = False
31
+
32
+
33
+ class LoadStreams:
34
+ """
35
+ Stream Loader for various types of video streams.
36
+
37
+ Suitable for use with `yolo predict source='rtsp://example.com/media.mp4'`, supports RTSP, RTMP, HTTP, and TCP streams.
38
+
39
+ Attributes:
40
+ sources (str): The source input paths or URLs for the video streams.
41
+ vid_stride (int): Video frame-rate stride, defaults to 1.
42
+ buffer (bool): Whether to buffer input streams, defaults to False.
43
+ running (bool): Flag to indicate if the streaming thread is running.
44
+ mode (str): Set to 'stream' indicating real-time capture.
45
+ imgs (list): List of image frames for each stream.
46
+ fps (list): List of FPS for each stream.
47
+ frames (list): List of total frames for each stream.
48
+ threads (list): List of threads for each stream.
49
+ shape (list): List of shapes for each stream.
50
+ caps (list): List of cv2.VideoCapture objects for each stream.
51
+ bs (int): Batch size for processing.
52
+
53
+ Methods:
54
+ __init__: Initialize the stream loader.
55
+ update: Read stream frames in daemon thread.
56
+ close: Close stream loader and release resources.
57
+ __iter__: Returns an iterator object for the class.
58
+ __next__: Returns source paths, transformed, and original images for processing.
59
+ __len__: Return the length of the sources object.
60
+ """
61
+
62
+ def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
63
+ """Initialize instance variables and check for consistent input stream shapes."""
64
+ torch.backends.cudnn.benchmark = True # faster for fixed-size inference
65
+ self.buffer = buffer # buffer input streams
66
+ self.running = True # running flag for Thread
67
+ self.mode = "stream"
68
+ self.vid_stride = vid_stride # video frame-rate stride
69
+
70
+ sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
71
+ n = len(sources)
72
+ self.fps = [0] * n # frames per second
73
+ self.frames = [0] * n
74
+ self.threads = [None] * n
75
+ self.caps = [None] * n # video capture objects
76
+ self.imgs = [[] for _ in range(n)] # images
77
+ self.shape = [[] for _ in range(n)] # image shapes
78
+ self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
79
+ for i, s in enumerate(sources): # index, source
80
+ # Start thread to read frames from video stream
81
+ st = f"{i + 1}/{n}: {s}... "
82
+ if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
83
+ # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
84
+ s = get_best_youtube_url(s)
85
+ s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
86
+ if s == 0 and (is_colab() or is_kaggle()):
87
+ raise NotImplementedError(
88
+ "'source=0' webcam not supported in Colab and Kaggle notebooks. "
89
+ "Try running 'source=0' in a local environment."
90
+ )
91
+ self.caps[i] = cv2.VideoCapture(s) # store video capture object
92
+ if not self.caps[i].isOpened():
93
+ raise ConnectionError(f"{st}Failed to open {s}")
94
+ w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
95
+ h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
96
+ fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
97
+ self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
98
+ "inf"
99
+ ) # infinite stream fallback
100
+ self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
101
+
102
+ success, im = self.caps[i].read() # guarantee first frame
103
+ if not success or im is None:
104
+ raise ConnectionError(f"{st}Failed to read images from {s}")
105
+ self.imgs[i].append(im)
106
+ self.shape[i] = im.shape
107
+ self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)
108
+ LOGGER.info(f"{st}Success �� ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
109
+ self.threads[i].start()
110
+ LOGGER.info("") # newline
111
+
112
+ # Check for common shapes
113
+ self.bs = self.__len__()
114
+
115
    def update(self, i, cap, stream):
        """
        Continuously read frames for stream `i` in a daemon thread.

        Args:
            i (int): Index of this stream in the internal buffers (imgs/frames/shape).
            cap (cv2.VideoCapture): Opened capture object for this stream.
            stream (str | int): Original source identifier, used to re-open the stream on read failure.
        """
        n, f = 0, self.frames[i]  # frame counter, total frame count (may be float('inf') for live streams)
        while self.running and cap.isOpened() and n < (f - 1):
            if len(self.imgs[i]) < 30:  # keep a <=30-image buffer
                n += 1
                cap.grab()  # .read() = .grab() followed by .retrieve()
                if n % self.vid_stride == 0:  # only decode every vid_stride-th grabbed frame
                    success, im = cap.retrieve()
                    if not success:
                        im = np.zeros(self.shape[i], dtype=np.uint8)  # substitute a black frame on failure
                        LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
                        cap.open(stream)  # re-open stream if signal was lost
                    if self.buffer:
                        self.imgs[i].append(im)  # accumulate frames; consumed FIFO in __next__
                    else:
                        self.imgs[i] = [im]  # keep only the most recent frame
            else:
                time.sleep(0.01)  # buffer is full; wait for the consumer to drain it
134
+
135
+ def close(self):
136
+ """Close stream loader and release resources."""
137
+ self.running = False # stop flag for Thread
138
+ for thread in self.threads:
139
+ if thread.is_alive():
140
+ thread.join(timeout=5) # Add timeout
141
+ for cap in self.caps: # Iterate through the stored VideoCapture objects
142
+ try:
143
+ cap.release() # release video capture
144
+ except Exception as e:
145
+ LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}")
146
+ cv2.destroyAllWindows()
147
+
148
+ def __iter__(self):
149
+ """Iterates through YOLO image feed and re-opens unresponsive streams."""
150
+ self.count = -1
151
+ return self
152
+
153
    def __next__(self):
        """
        Return the next batch of frames, one image per source stream.

        Returns:
            (list): Source paths/names for each stream.
            (list): One BGR image (np.ndarray) per stream.
            (None): Placeholder for a video-capture object (not used for streams).
            (str): Empty info string.

        Raises:
            StopIteration: When a reader thread has died or the user presses 'q' in an OpenCV window.
        """
        self.count += 1

        images = []
        for i, x in enumerate(self.imgs):
            # Wait until a frame is available in each buffer
            while not x:
                if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"):  # q to quit
                    self.close()
                    raise StopIteration
                time.sleep(1 / min(self.fps))  # sleep one frame period of the slowest stream
                x = self.imgs[i]
                if not x:
                    LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}")

            # Get and remove the first frame from imgs buffer (FIFO when buffering is enabled)
            if self.buffer:
                images.append(x.pop(0))

            # Get the last frame, and clear the rest from the imgs buffer
            else:
                # Fall back to a black frame if the buffer emptied between the checks above
                images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
                x.clear()

        return self.sources, images, None, ""
179
+
180
+ def __len__(self):
181
+ """Return the length of the sources object."""
182
+ return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
183
+
184
+
185
class LoadScreenshots:
    """
    YOLOv8 screenshot dataloader.

    This class manages the loading of screenshot images for processing with YOLOv8.
    Suitable for use with `yolo predict source=screen`.

    Attributes:
        source (str): The source input indicating which screen to capture.
        screen (int): The screen number to capture.
        left (int): The left coordinate for screen capture area.
        top (int): The top coordinate for screen capture area.
        width (int): The width of the screen capture area.
        height (int): The height of the screen capture area.
        mode (str): Set to 'stream' indicating real-time capture.
        frame (int): Counter for captured frames.
        sct (mss.mss): Screen capture object from `mss` library.
        bs (int): Batch size, set to 1.
        monitor (dict): Monitor configuration details.

    Methods:
        __iter__: Returns an iterator object.
        __next__: Captures the next screenshot and returns it.
    """

    def __init__(self, source):
        """Source = [screen_number left top width height] (pixels)."""
        check_requirements("mss")
        import mss  # noqa

        source, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
        if len(params) == 1:
            self.screen = int(params[0])  # e.g. 'screen 1' → full capture of screen 1
        elif len(params) == 4:
            left, top, width, height = (int(x) for x in params)  # region on default screen 0
        elif len(params) == 5:
            self.screen, left, top, width, height = (int(x) for x in params)  # screen number + region
        self.mode = "stream"
        self.frame = 0
        self.sct = mss.mss()
        self.bs = 1

        # Parse monitor shape; user-supplied offsets are relative to the chosen monitor's origin
        monitor = self.sct.monitors[self.screen]
        self.top = monitor["top"] if top is None else (monitor["top"] + top)
        self.left = monitor["left"] if left is None else (monitor["left"] + left)
        self.width = width or monitor["width"]
        self.height = height or monitor["height"]
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        """Returns an iterator of the object."""
        return self

    def __next__(self):
        """mss screen capture: get raw pixels from the screen as np array."""
        im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3]  # BGRA to BGR (drop alpha channel)
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        self.frame += 1
        return [str(self.screen)], [im0], None, s  # screen, img, vid_cap, string
247
+
248
+
249
class LoadImages:
    """
    YOLOv8 image/video dataloader.

    This class manages the loading and pre-processing of image and video data for YOLOv8. It supports loading from
    various formats, including single image files, video files, and lists of image and video paths.

    Attributes:
        files (list): List of image and video file paths.
        nf (int): Total number of files (images and videos).
        video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
        mode (str): Current mode, 'image' or 'video'.
        vid_stride (int): Stride for video frame-rate, defaults to 1.
        bs (int): Batch size, set to 1 for this class.
        cap (cv2.VideoCapture): Video capture object for OpenCV.
        frame (int): Frame counter for video.
        frames (int): Total number of frames in the video.
        count (int): Counter for iteration, initialized at 0 during `__iter__()`.

    Methods:
        _new_video(path): Create a new cv2.VideoCapture object for a given video path.
    """

    def __init__(self, path, vid_stride=1):
        """
        Initialize the Dataloader and raise FileNotFoundError if file not found.

        Args:
            path (str | Path | list): A file, directory, glob pattern, list of sources, or a *.txt file
                listing one source per line.
            vid_stride (int): Keep every `vid_stride`-th video frame, defaults to 1.
        """
        parent = None
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
            parent = Path(path).parent
            path = Path(path).read_text().splitlines()  # list of sources
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            a = str(Path(p).absolute())  # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
            if "*" in a:
                files.extend(sorted(glob.glob(a, recursive=True)))  # glob
            elif os.path.isdir(a):
                files.extend(sorted(glob.glob(os.path.join(a, "*.*"))))  # dir
            elif os.path.isfile(a):
                files.append(a)  # files (absolute or relative to CWD)
            elif parent and (parent / p).is_file():
                files.append(str((parent / p).absolute()))  # files (relative to *.txt file parent)
            else:
                raise FileNotFoundError(f"{p} does not exist")

        images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = "image"
        self.vid_stride = vid_stride  # video frame-rate stride
        self.bs = 1
        if videos:
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        if self.nf == 0:
            # Bug fix: previous code interpolated the loop variable `p`, which is unbound (NameError)
            # when `path` is an empty list and otherwise names only the last source; report the input.
            raise FileNotFoundError(
                f"No images or videos found in {path}. "
                f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
            )

    def __iter__(self):
        """Returns an iterator object for VideoStream or ImageFolder."""
        self.count = 0
        return self

    def __next__(self):
        """Return next image, path and metadata from dataset."""
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = "video"
            for _ in range(self.vid_stride):
                self.cap.grab()  # grab skips frames cheaply; only the last one is decoded below
            success, im0 = self.cap.retrieve()
            while not success:
                # Current video is exhausted or unreadable: advance to the next file
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                success, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            if im0 is None:
                raise FileNotFoundError(f"Image Not Found {path}")
            s = f"image {self.count}/{self.nf} {path}: "

        return [path], [im0], self.cap, s

    def _new_video(self, path):
        """Create a new video capture object for `path` and reset the frame counter."""
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

    def __len__(self):
        """Returns the number of files in the object."""
        return self.nf  # number of files
361
+
362
+
363
class LoadPilAndNumpy:
    """
    Load images from PIL and Numpy arrays for batch processing.

    Validates every input and converts PIL images to contiguous BGR numpy arrays so that
    downstream prediction code receives a uniform format.

    Attributes:
        paths (list): Image paths or autogenerated filenames ('image{i}.jpg').
        im0 (list): Images stored as numpy arrays.
        mode (str): Always 'image' for this loader.
        bs (int): Batch size, equal to len(im0).
        count (int): Iteration counter, set to 0 by `__iter__()`.

    Methods:
        _single_check(im): Validate and format a single image to a numpy array.
    """

    def __init__(self, im0):
        """Initialize PIL and Numpy Dataloader from a single image or a list of images."""
        images = im0 if isinstance(im0, list) else [im0]
        self.paths = [getattr(image, "filename", f"image{i}.jpg") for i, image in enumerate(images)]
        self.im0 = [self._single_check(image) for image in images]
        self.mode = "image"
        self.bs = len(self.im0)

    @staticmethod
    def _single_check(im):
        """Validate one image; PIL input is converted to a contiguous BGR numpy array."""
        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
        if isinstance(im, Image.Image):
            if im.mode != "RGB":
                im = im.convert("RGB")
            im = np.asarray(im)[:, :, ::-1]  # RGB -> BGR channel order
            im = np.ascontiguousarray(im)  # contiguous
        return im

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.im0)

    def __iter__(self):
        """Reset the iteration counter and return self."""
        self.count = 0
        return self

    def __next__(self):
        """Yield the whole batch exactly once (single-pass batch inference), then stop."""
        if self.count == 1:
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, None, ""
418
+
419
+
420
class LoadTensor:
    """
    Load images from torch.Tensor data.

    Wraps a BCHW tensor (or a CHW tensor, which is unsqueezed to a batch of one) so it can be
    iterated like the other loaders, after validating shape and value range.

    Attributes:
        im0 (torch.Tensor): The validated input tensor.
        bs (int): Batch size, taken from im0.shape[0].
        mode (str): Always 'image'.
        paths (list): Image paths or autogenerated filenames.
        count (int): Iteration counter, set to 0 by `__iter__()`.

    Methods:
        _single_check(im, stride): Validate and possibly modify the input tensor.
    """

    def __init__(self, im0) -> None:
        """Initialize Tensor Dataloader."""
        self.im0 = self._single_check(im0)
        self.bs = self.im0.shape[0]
        self.mode = "image"
        self.paths = [getattr(image, "filename", f"image{i}.jpg") for i, image in enumerate(im0)]

    @staticmethod
    def _single_check(im, stride=32):
        """Validate the tensor: BCHW shape, stride-divisible spatial dims, values in 0.0-1.0."""
        s = (
            f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
            f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
        )
        if im.ndim != 4:
            if im.ndim != 3:
                raise ValueError(s)
            LOGGER.warning(s)
            im = im.unsqueeze(0)  # CHW -> 1CHW
        if im.shape[2] % stride or im.shape[3] % stride:
            raise ValueError(s)
        if im.max() > 1.0 + torch.finfo(im.dtype).eps:  # torch.float32 eps is 1.2e-07
            LOGGER.warning(
                f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
                f"Dividing input by 255."
            )
            im = im.float() / 255.0

        return im

    def __iter__(self):
        """Reset the iteration counter and return self."""
        self.count = 0
        return self

    def __next__(self):
        """Yield the whole batch exactly once, then stop."""
        if self.count == 1:
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, None, ""

    def __len__(self):
        """Returns the batch size."""
        return self.bs
482
+
483
+
484
def autocast_list(source):
    """Merges a list of source of different types into a list of numpy arrays or PIL images."""
    merged = []
    for item in source:
        if isinstance(item, (str, Path)):  # filename or uri
            raw = requests.get(item, stream=True).raw if str(item).startswith("http") else item
            merged.append(Image.open(raw))
        elif isinstance(item, (Image.Image, np.ndarray)):  # PIL or np Image
            merged.append(item)
        else:
            raise TypeError(
                f"type {type(item).__name__} is not a supported Ultralytics prediction source type. \n"
                f"See https://docs.ultralytics.com/modes/predict for supported source types."
            )

    return merged
499
+
500
+
501
LOADERS = LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots  # tuple of the loader classes defined above
502
+
503
+
504
def get_best_youtube_url(url, use_pafy=True):
    """
    Retrieves the URL of the best quality MP4 video stream from a given YouTube video.

    This function uses the pafy or yt_dlp library to extract the video info from YouTube. It then finds the highest
    quality MP4 format that has video codec but no audio codec, and returns the URL of this video stream.

    Args:
        url (str): The URL of the YouTube video.
        use_pafy (bool): Use the pafy package, default=True, otherwise use yt_dlp package.

    Returns:
        (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.

    Note:
        NOTE(review): the pinned 'youtube_dl==2020.12.2' backend used by pafy is long unmaintained and may no
        longer work against current YouTube; the yt_dlp path (use_pafy=False) is the maintained alternative —
        confirm before changing the default.
    """
    if use_pafy:
        check_requirements(("pafy", "youtube_dl==2020.12.2"))
        import pafy  # noqa

        return pafy.new(url).getbestvideo(preftype="mp4").url
    else:
        check_requirements("yt-dlp")
        import yt_dlp

        with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
            info_dict = ydl.extract_info(url, download=False)  # extract info
        for f in reversed(info_dict.get("formats", [])):  # reversed because best is usually last
            # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size
            good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
            if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
                return f.get("url")
yolov8_model/ultralytics/data/scripts/download_weights.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download latest models from https://github.com/ultralytics/assets/releases
# Example usage: bash ultralytics/data/scripts/download_weights.sh
# parent
# └── weights
#     ├── yolov8n.pt  ← downloads here
#     ├── yolov8s.pt
#     └── ...

# Requires the 'ultralytics' package importable from the current Python environment
python - <<EOF
from ultralytics.utils.downloads import attempt_download_asset

# All sizes (n/s/m/l/x) across detect (''), classify ('-cls'), segment ('-seg') and pose ('-pose') variants
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')]
for x in assets:
    attempt_download_asset(f'weights/{x}')

EOF
yolov8_model/ultralytics/data/scripts/get_coco.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download COCO 2017 dataset https://cocodataset.org
# Example usage: bash data/scripts/get_coco.sh
# parent
# ├── ultralytics
# └── datasets
#     └── coco  ← downloads here

# Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments
if [ "$#" -gt 0 ]; then
  for opt in "$@"; do
    case "${opt}" in
    --train) train=true ;;
    --val) val=true ;;
    --test) test=true ;;
    --segments) segments=true ;;
    --sama) sama=true ;;
    esac
  done
else
  # Defaults when called with no arguments: train + val images, box labels only
  train=true
  val=true
  test=false
  segments=false
  sama=false
fi

# Download/unzip labels (runs in the background; 'wait' at the bottom blocks until complete)
d='../datasets' # unzip directory
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
if [ "$segments" == "true" ]; then
  f='coco2017labels-segments.zip' # 169 MB
elif [ "$sama" == "true" ]; then
  f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/
else
  f='coco2017labels.zip' # 46 MB
fi
echo 'Downloading' $url$f ' ...'
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &

# Download/unzip images (each split downloads as its own background job)
d='../datasets/coco/images' # unzip directory
url=http://images.cocodataset.org/zips/
if [ "$train" == "true" ]; then
  f='train2017.zip' # 19G, 118k images
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
if [ "$val" == "true" ]; then
  f='val2017.zip' # 1G, 5k images
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
if [ "$test" == "true" ]; then
  f='test2017.zip' # 7G, 41k images (optional)
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
wait # finish background tasks
yolov8_model/ultralytics/data/scripts/get_coco128.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017)
# Example usage: bash data/scripts/get_coco128.sh
# parent
# ├── ultralytics
# └── datasets
#     └── coco128  ← downloads here

# Download/unzip images and labels (background job; 'wait' below blocks until it finishes)
d='../datasets' # unzip directory
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
f='coco128.zip' # or 'coco128-segments.zip', 68 MB
echo 'Downloading' $url$f ' ...'
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &

wait # finish background tasks
yolov8_model/ultralytics/data/scripts/get_imagenet.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download ILSVRC2012 ImageNet dataset https://image-net.org
# Example usage: bash data/scripts/get_imagenet.sh
# parent
# ├── ultralytics
# └── datasets
#     └── imagenet  ← downloads here
# NOTE(review): direct downloads from image-net.org presumably require accepted terms/credentials — verify.

# Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val
if [ "$#" -gt 0 ]; then
  for opt in "$@"; do
    case "${opt}" in
    --train) train=true ;;
    --val) val=true ;;
    esac
  done
else
  # Default: download both splits
  train=true
  val=true
fi

# Make dir
d='../datasets/imagenet' # unzip directory
mkdir -p $d && cd $d

# Download/unzip train
if [ "$train" == "true" ]; then
  wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images
  mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
  tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
  # The train archive contains one nested tar per class; extract each into its own class directory
  find . -name "*.tar" | while read NAME; do
    mkdir -p "${NAME%.tar}"
    tar -xf "${NAME}" -C "${NAME%.tar}"
    rm -f "${NAME}"
  done
  cd ..
fi

# Download/unzip val
if [ "$val" == "true" ]; then
  wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images
  mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar
  wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs
fi

# Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail)
# rm train/n04266014/n04266014_10835.JPEG

# TFRecords (optional)
# wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt
yolov8_model/ultralytics/data/split_dota.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import itertools
4
+ from glob import glob
5
+ from math import ceil
6
+ from pathlib import Path
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+ from tqdm import tqdm
12
+
13
+ from yolov8_model.ultralytics.data.utils import exif_size, img2label_paths
14
+ from yolov8_model.ultralytics.utils.checks import check_requirements
15
+
16
+ check_requirements("shapely")
17
+ from shapely.geometry import Polygon
18
+
19
+
20
def bbox_iof(polygon1, bbox2, eps=1e-6):
    """
    Calculate intersection-over-foreground (IoF) between polygons and axis-aligned boxes.

    IoF divides each intersection area by the polygon's own area (the "foreground"), so it
    measures how much of each polygon lies inside each box.

    Args:
        polygon1 (np.ndarray): Polygon coordinates, shape (n, 8) as 4 (x, y) corner pairs.
        bbox2 (np.ndarray): Bounding boxes, shape (m, 4) as (x1, y1, x2, y2).
        eps (float): Lower bound on polygon areas to avoid division by zero.

    Returns:
        (np.ndarray): IoF matrix of shape (n, m).
    """
    polygon1 = polygon1.reshape(-1, 4, 2)
    lt_point = np.min(polygon1, axis=-2)  # top-left of each polygon's axis-aligned bounding box
    rb_point = np.max(polygon1, axis=-2)  # bottom-right of each polygon's AABB
    bbox1 = np.concatenate([lt_point, rb_point], axis=-1)  # polygon AABBs, shape (n, 4)

    # Pairwise AABB intersection areas, used as a cheap pre-filter before exact polygon clipping
    lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2])
    rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:])
    wh = np.clip(rb - lt, 0, np.inf)
    h_overlaps = wh[..., 0] * wh[..., 1]

    l, t, r, b = (bbox2[..., i] for i in range(4))
    polygon2 = np.stack([l, t, r, t, r, b, l, b], axis=-1).reshape(-1, 4, 2)  # boxes as 4-corner polygons

    sg_polys1 = [Polygon(p) for p in polygon1]
    sg_polys2 = [Polygon(p) for p in polygon2]
    overlaps = np.zeros(h_overlaps.shape)
    # Exact shapely intersection only for pairs whose AABBs overlap
    for p in zip(*np.nonzero(h_overlaps)):
        overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area
    unions = np.array([p.area for p in sg_polys1], dtype=np.float32)  # polygon ("foreground") areas
    unions = unions[..., None]

    unions = np.clip(unions, eps, np.inf)  # guard against degenerate zero-area polygons
    outputs = overlaps / unions
    if outputs.ndim == 1:
        outputs = outputs[..., None]
    return outputs
54
+
55
+
56
def load_yolo_dota(data_root, split="train"):
    """
    Load DOTA dataset annotations for one split.

    Args:
        data_root (str): Data root directory.
        split (str): The split data set, could be train or val.

    Returns:
        (list): One dict per image with keys `ori_size` (h, w), `label` (np.ndarray) and `filepath`.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    assert split in ["train", "val"]
    im_dir = Path(data_root) / "images" / split
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(Path(data_root) / "images" / split / "*"))
    lb_files = img2label_paths(im_files)
    annos = []
    for img_path, label_path in zip(im_files, lb_files):
        w, h = exif_size(Image.open(img_path))  # EXIF-corrected (width, height)
        with open(label_path) as fh:
            rows = [line.split() for line in fh.read().strip().splitlines() if len(line)]
        annos.append({"ori_size": (h, w), "label": np.array(rows, dtype=np.float32), "filepath": img_path})
    return annos
87
+
88
+
89
def get_windows(im_size, crop_sizes=(1024,), gaps=(200,), im_rate_thr=0.6, eps=0.01):
    """
    Get the coordinates of sliding crop windows covering an image.

    Args:
        im_size (tuple): Original image size, (h, w).
        crop_sizes (tuple | list): Crop size of windows.
        gaps (tuple | list): Overlap between adjacent crops; paired element-wise with crop_sizes.
        im_rate_thr (float): Minimum ratio of in-image window area to total window area.
        eps (float): Tolerance used to keep the best-covering window(s) when none passes the threshold.

    Returns:
        (np.ndarray): Window coordinates of shape (m, 4) as (x_start, y_start, x_stop, y_stop).

    Note:
        Bug fix: the default arguments were mutable lists (`[1024]`, `[200]`); they are now tuples.
    """
    h, w = im_size
    windows = []
    for crop_size, gap in zip(crop_sizes, gaps):
        assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
        step = crop_size - gap

        xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
        xs = [step * i for i in range(xn)]
        if len(xs) > 1 and xs[-1] + crop_size > w:
            xs[-1] = w - crop_size  # shift the last window back so it ends exactly at the image edge

        yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1)
        ys = [step * i for i in range(yn)]
        if len(ys) > 1 and ys[-1] + crop_size > h:
            ys[-1] = h - crop_size

        start = np.array(list(itertools.product(xs, ys)), dtype=np.int64)
        stop = start + crop_size
        windows.append(np.concatenate([start, stop], axis=1))
    windows = np.concatenate(windows, axis=0)

    # Keep only windows that mostly lie inside the image
    im_in_wins = windows.copy()
    im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w)
    im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h)
    im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1])
    win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])
    im_rates = im_areas / win_areas
    if not (im_rates > im_rate_thr).any():
        # No window passes the threshold (image smaller than the crop): keep the best-covering window(s)
        max_rate = im_rates.max()
        im_rates[abs(im_rates - max_rate) < eps] = 1
    return windows[im_rates > im_rate_thr]
130
+
131
+
132
def get_window_obj(anno, windows, iof_thr=0.7):
    """
    Assign each annotated object to every crop window it sufficiently overlaps.

    Args:
        anno (dict): Annotation dict with `ori_size` (h, w) and normalized `label` array.
        windows (np.ndarray): Window coordinates of shape (m, 4).
        iof_thr (float): Minimum IoF for an object to be kept in a window.

    Returns:
        (list): One label array per window.

    Note:
        Denormalizes `anno["label"]` in place — the caller's array is scaled to pixel coordinates.
    """
    h, w = anno["ori_size"]
    label = anno["label"]
    if len(label):
        label[:, 1::2] *= w  # in-place: x coordinates become pixels
        label[:, 2::2] *= h  # in-place: y coordinates become pixels
        iofs = bbox_iof(label[:, 1:], windows)
        # Unnormalized and misaligned coordinates
        return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]  # window_anns
    else:
        return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]  # window_anns
144
+
145
+
146
def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
    """
    Crop images and save new labels.

    Args:
        anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
        windows (list): A list of windows coordinates.
        window_objs (list): A list of labels inside each window.
        im_dir (str): The output directory path of images.
        lb_dir (str): The output directory path of labels.

    Notes:
        Mutates each window's label array in place (translates to window origin, then normalizes
        to the patch size). The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    im = cv2.imread(anno["filepath"])
    name = Path(anno["filepath"]).stem
    for i, window in enumerate(windows):
        x_start, y_start, x_stop, y_stop = window.tolist()
        new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"  # DOTA patch naming: stem__size__x___y
        patch_im = im[y_start:y_stop, x_start:x_stop]
        ph, pw = patch_im.shape[:2]

        cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im)
        label = window_objs[i]
        if len(label) == 0:
            continue  # the image patch is still saved even when it contains no objects
        # Translate pixel coordinates into the window frame, then normalize to patch size
        label[:, 1::2] -= x_start
        label[:, 2::2] -= y_start
        label[:, 1::2] /= pw
        label[:, 2::2] /= ph

        with open(Path(lb_dir) / f"{new_name}.txt", "w") as f:
            for lb in label:
                formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]]
                f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
+
189
+
190
def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024,), gaps=(200,)):
    """
    Split both images and labels of one DOTA split into crop windows.

    Args:
        data_root (str): Root of the source dataset.
        save_dir (str): Output root for the cropped dataset.
        split (str): The split to process, 'train' or 'val'.
        crop_sizes (tuple | list): Window sizes in pixels.
        gaps (tuple | list): Overlaps between adjacent windows, paired with crop_sizes.

    Notes:
        Bug fix: the default arguments were mutable lists (`[1024]`, `[200]`); they are now tuples.
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - split
                - labels
                    - split
        and the output directory structure is:
            - save_dir
                - images
                    - split
                - labels
                    - split
    """
    im_dir = Path(save_dir) / "images" / split
    im_dir.mkdir(parents=True, exist_ok=True)
    lb_dir = Path(save_dir) / "labels" / split
    lb_dir.mkdir(parents=True, exist_ok=True)

    annos = load_yolo_dota(data_root, split=split)
    for anno in tqdm(annos, total=len(annos), desc=split):
        windows = get_windows(anno["ori_size"], crop_sizes, gaps)
        window_objs = get_window_obj(anno, windows)  # NOTE: denormalizes anno["label"] in place
        crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
218
+
219
+
220
def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
    """
    Split train and val set of DOTA.

    Args:
        data_root (str): Root of the source dataset.
        save_dir (str): Output root for the cropped dataset.
        crop_size (int): Base window size in pixels.
        gap (int): Base overlap between adjacent windows in pixels.
        rates (tuple | list): Scale rates; each rate r yields windows of size int(crop_size / r).

    Notes:
        Bug fix: the default `rates` argument was a mutable list (`[1.0]`); it is now a tuple.
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
        and the output directory structure is:
            - save_dir
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    # One (crop_size, gap) pair per rate
    crop_sizes = [int(crop_size / r) for r in rates]
    gaps = [int(gap / r) for r in rates]
    for split in ["train", "val"]:
        split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
248
+
249
+
250
def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
    """
    Split test set of DOTA, labels are not included within this set.

    Args:
        data_root (str): Root of the source dataset.
        save_dir (str): Output root for the cropped dataset.
        crop_size (int): Base window size in pixels.
        gap (int): Base overlap between adjacent windows in pixels.
        rates (tuple | list): Scale rates; each rate r yields windows of size int(crop_size / r).

    Notes:
        Bug fix: the default `rates` argument was a mutable list (`[1.0]`); it is now a tuple.
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - test
        and the output directory structure is:
            - save_dir
                - images
                    - test
    """
    # One (crop_size, gap) pair per rate
    crop_sizes = [int(crop_size / r) for r in rates]
    gaps = [int(gap / r) for r in rates]
    save_dir = Path(save_dir) / "images" / "test"
    save_dir.mkdir(parents=True, exist_ok=True)

    im_dir = Path(data_root) / "images" / "test"
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(im_dir / "*"))
    for im_file in tqdm(im_files, total=len(im_files), desc="test"):
        w, h = exif_size(Image.open(im_file))  # EXIF-corrected (width, height)
        windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
        im = cv2.imread(im_file)
        name = Path(im_file).stem
        for window in windows:
            x_start, y_start, x_stop, y_stop = window.tolist()
            new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"  # DOTA patch naming: stem__size__x___y
            patch_im = im[y_start:y_stop, x_start:x_stop]
            cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im)
284
+
285
+
286
if __name__ == "__main__":
    # Example entry point: crop DOTAv2 train/val (with labels) and test (images only) into DOTAv2-split
    split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split")
    split_test(data_root="DOTAv2", save_dir="DOTAv2-split")
+ split_test(data_root="DOTAv2", save_dir="DOTAv2-split")
yolov8_model/ultralytics/data/utils.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import contextlib
4
+ import hashlib
5
+ import json
6
+ import os
7
+ import random
8
+ import subprocess
9
+ import time
10
+ import zipfile
11
+ from multiprocessing.pool import ThreadPool
12
+ from pathlib import Path
13
+ from tarfile import is_tarfile
14
+
15
+ import cv2
16
+ import numpy as np
17
+ from PIL import Image, ImageOps
18
+
19
+ from yolov8_model.ultralytics.nn.autobackend import check_class_names
20
+ from yolov8_model.ultralytics.utils import (
21
+ DATASETS_DIR,
22
+ LOGGER,
23
+ NUM_THREADS,
24
+ ROOT,
25
+ SETTINGS_YAML,
26
+ TQDM,
27
+ clean_url,
28
+ colorstr,
29
+ emojis,
30
+ yaml_load,
31
+ yaml_save,
32
+ )
33
+ from yolov8_model.ultralytics.utils.checks import check_file, check_font, is_ascii
34
+ from yolov8_model.ultralytics.utils.downloads import download, safe_download, unzip_file
35
+ from yolov8_model.ultralytics.utils.ops import segments2boxes
36
+
37
+ HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
38
+ IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # image suffixes
39
+ VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm" # video suffixes
40
+ PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
41
+
42
+
43
def img2label_paths(img_paths):
    """Derive the label .txt path for each image path by swapping the /images/ directory for /labels/."""
    img_dir = f"{os.sep}images{os.sep}"  # substring marking the images directory
    lbl_dir = f"{os.sep}labels{os.sep}"  # substring marking the labels directory
    label_paths = []
    for img in img_paths:
        # Replace only the LAST occurrence of /images/, then swap the file extension for .txt
        swapped = lbl_dir.join(img.rsplit(img_dir, 1))
        label_paths.append(swapped.rsplit(".", 1)[0] + ".txt")
    return label_paths
47
+
48
+
49
def get_hash(paths):
    """Return a single SHA-256 hex digest for a list of file/dir paths, mixing total size and path names."""
    total_size = 0
    for p in paths:
        if os.path.exists(p):  # missing paths contribute nothing to the size
            total_size += os.path.getsize(p)
    digest = hashlib.sha256(str(total_size).encode())  # seed the hash with the combined size
    digest.update("".join(paths).encode())  # then mix in the concatenated path strings
    return digest.hexdigest()
55
+
56
+
57
def exif_size(img: "Image.Image"):
    """Return the PIL (width, height), swapped when EXIF says the JPEG is rotated 90/270 degrees."""
    size = img.size  # PIL reports (width, height)
    if img.format != "JPEG":  # EXIF orientation is only honoured for JPEG images
        return size
    with contextlib.suppress(Exception):  # any EXIF read failure falls back to the raw size
        exif = img.getexif()
        if exif:
            # 274 is the EXIF orientation tag; values 6 and 8 mean 90/270-degree rotation
            if exif.get(274, None) in (6, 8):
                size = size[1], size[0]
    return size
68
+
69
+
70
def verify_image(args):
    """Verify a single (image, class) pair for a classification dataset, repairing truncated JPEGs."""
    (im_file, cls), prefix = args
    nf = nc = 0  # counters: found, corrupt
    msg = ""
    try:
        im = Image.open(im_file)
        im.verify()  # PIL integrity check (does not decode the full image)
        w, h = exif_size(im)
        shape = (h, w)  # store as (height, width)
        assert shape[0] > 9 and shape[1] > 9, f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
        if im.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)  # the last two bytes must be the JPEG end-of-image marker
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    # Re-save a clean copy (also bakes in the EXIF orientation)
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
        nf = 1
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), nf, nc, msg
93
+
94
+
95
def verify_image_label(args):
    """
    Verify one image-label pair for a detection/segmentation/pose dataset.

    Args:
        args (tuple): (im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim), where `keypoint` toggles
            pose-label parsing, `num_cls` is the dataset class count and `nkpt`/`ndim` describe keypoints.

    Returns:
        (tuple | list): (im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg) on success, or a list of
            five Nones plus the counters/message when the pair is corrupt. nm/nf/ne/nc = missing/found/empty/corrupt.
    """
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
        if im.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG missing its end-of-image marker
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    points = lb[:, 1:]
                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"

                # All labels. Class indices are 0-based so a valid max class is num_cls - 1: the check must be
                # strict '<' (the previous '<=' wrongly accepted class index == num_cls).
                max_cls = lb[:, 0].max()  # max label count
                assert max_cls < num_cls, (
                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}"
                )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            # BUGFIX: this branch previously tested `keypoints` (still None at this point), so pose datasets
            # with a missing label file got a 5-column array; test the `keypoint` flag instead.
            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                # Keypoints with any negative coordinate are marked invisible (visibility flag 0)
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
165
+
166
+
167
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
    """
    Rasterize polygons into a single binary mask.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int, optional): Fill value for the polygon interiors. Defaults to 1.
        downsample_ratio (int, optional): Factor by which the mask is shrunk after filling. Defaults to 1.

    Returns:
        (np.ndarray): A binary mask of the specified image size with the polygons filled in.
    """
    canvas = np.zeros(imgsz, dtype=np.uint8)
    pts = np.asarray(polygons, dtype=np.int32)
    pts = pts.reshape((pts.shape[0], -1, 2))
    cv2.fillPoly(canvas, pts, color=color)
    # Fill at full resolution and only then resize: keeps loss behaviour identical when mask-ratio == 1
    target_wh = (imgsz[1] // downsample_ratio, imgsz[0] // downsample_ratio)  # cv2.resize takes (width, height)
    return cv2.resize(canvas, target_wh)
188
+
189
+
190
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
    """
    Rasterize each polygon into its own binary mask and stack the results.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int): Fill value for the polygon interiors.
        downsample_ratio (int, optional): Factor by which each mask is shrunk. Defaults to 1.

    Returns:
        (np.ndarray): Stacked per-instance binary masks of the specified image size.
    """
    per_instance = []
    for polygon in polygons:
        per_instance.append(polygon2mask(imgsz, [polygon.reshape(-1)], color, downsample_ratio))
    return np.array(per_instance)
205
+
206
+
207
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
    """Merge per-instance masks into one index mask where larger instances are drawn first (smaller on top)."""
    # uint8 would overflow past 255 instance ids, so switch to int32 for crowded images
    dtype = np.int32 if len(segments) > 255 else np.uint8
    merged = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), dtype=dtype)
    instance_masks = []
    areas = []
    for seg in segments:
        m = polygon2mask(imgsz, [seg.reshape(-1)], downsample_ratio=downsample_ratio, color=1)
        instance_masks.append(m)
        areas.append(m.sum())
    order = np.argsort(-np.asarray(areas))  # largest area first so small instances end up on top
    instance_masks = np.array(instance_masks)[order]
    for idx in range(len(segments)):
        layer = instance_masks[idx] * (idx + 1)
        merged = merged + layer
        # Overlapping pixels would otherwise sum to an invalid id; clamp to the most recent instance id
        merged = np.clip(merged, a_min=0, a_max=idx + 1)
    return merged, order
227
+
228
+
229
def find_dataset_yaml(path: Path) -> Path:
    """
    Locate the single YAML file describing a Detect, Segment or Pose dataset.

    The directory root is searched first and only then recursively; when several YAMLs are found, files whose
    stem matches the directory name are preferred. Exactly one match must remain, otherwise an AssertionError
    is raised.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.
    """
    candidates = list(path.glob("*.yaml"))
    if not candidates:  # nothing at root level, fall back to a recursive search
        candidates = list(path.rglob("*.yaml"))
    assert candidates, f"No YAML file found in '{path.resolve()}'"
    if len(candidates) > 1:
        candidates = [f for f in candidates if f.stem == path.stem]  # prefer *.yaml files that match
    assert len(candidates) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(candidates)}.\n{candidates}"
    return candidates[0]
249
+
250
+
251
def check_det_dataset(dataset, autodownload=True):
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.

    Returns:
        (dict): Parsed dataset information and paths.

    Raises:
        SyntaxError: If required YAML keys ('train'/'val' and 'names' or 'nc') are missing or inconsistent.
        FileNotFoundError: If validation images are missing and no autodownload is possible.
    """

    file = check_file(dataset)  # resolve descriptor to a local file (downloads it when given a URL)

    # Download (optional)
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        # Archive passed directly: extract into DATASETS_DIR, locate its YAML, and disable further autodownload
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        extract_dir, autodownload = file.parent, False

    # Read YAML
    data = yaml_load(file, append_filename=True)  # dictionary

    # Checks
    for k in "train", "val":
        if k not in data:
            if k != "val" or "validation" not in data:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            # Accept the alternate 'validation' spelling by renaming the key in place
            LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
    if "names" not in data and "nc" not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if "names" not in data:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]  # synthesize placeholder class names
    else:
        data["nc"] = len(data["names"])  # derive nc from names so the two always agree

    data["names"] = check_class_names(data["names"])

    # Resolve paths
    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()

    # Set paths
    data["path"] = path  # download scripts
    for k in "train", "val", "test":
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith("../"):
                    # Tolerate legacy '../'-prefixed relative paths by stripping the prefix and retrying
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse YAML
    val, s = (data.get(x) for x in ("val", "download"))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
                raise FileNotFoundError(m)
            t = time.time()
            r = None  # success
            if s.startswith("http") and s.endswith(".zip"):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith("bash "):  # bash script
                LOGGER.info(f"Running {s} ...")
                r = os.system(s)
            else:  # python script
                # NOTE(review): executes code supplied by the dataset YAML — only use trusted dataset descriptors
                exec(s, {"yaml": data})
            dt = f"({round(time.time() - t, 1)}s)"
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")
    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

    return data  # dictionary
343
+
344
+
345
def check_cls_dataset(dataset, split=""):
    """
    Checks a classification dataset such as Imagenet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.

    Args:
        dataset (str | Path): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.

    Returns:
        (dict): A dictionary containing the following keys:
            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary of class names in the dataset.

    Raises:
        FileNotFoundError: If the train split contains no images.
    """

    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        # Dataset missing locally: 'imagenet' uses a helper shell script, anything else a GitHub release zip
        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)
    train_set = data_dir / "train"
    # Accept either a 'val' or the alternate 'validation' directory name; None when neither exists
    val_set = (
        data_dir / "val"
        if (data_dir / "val").exists()
        else data_dir / "validation"
        if (data_dir / "validation").exists()
        else None
    )  # data/test or data/val
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test
    if split == "val" and not val_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
    elif split == "test" and not test_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")

    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes
    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))  # map index 0..nc-1 -> class name, sorted alphabetically

    # Print to console
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f'{colorstr(f"{k}:")} {v}...'
        if v is None:
            LOGGER.info(prefix)
        else:
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
            if nf == 0:
                if k == "train":
                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                else:
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
            elif nd != nc:
                # Each class must live in its own directory; a mismatch means a malformed split
                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
            else:
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
419
+
420
+
421
class HUBDatasetStats:
    """
    A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
        autodownload (bool): Attempt to download dataset if not found locally. Default is False.

    Example:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
        ```python
        from ultralytics.data.utils import HUBDatasetStats

        stats = HUBDatasetStats('path/to/coco8.zip', task='detect')  # detect dataset
        stats = HUBDatasetStats('path/to/coco8-seg.zip', task='segment')  # segment dataset
        stats = HUBDatasetStats('path/to/coco8-pose.zip', task='pose')  # pose dataset
        stats = HUBDatasetStats('path/to/imagenet10.zip', task='classify')  # classification dataset

        stats.get_json(save=True)
        stats.process_images()
        ```
    """

    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
        """Initialize class: verify the dataset, resolve its root and prepare the '-hub' output directory."""
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        self.task = task  # detect, segment, pose, classify
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = yaml_load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                yaml_save(yaml_path, data)
                data = check_det_dataset(yaml_path, autodownload)  # dict
                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
            except Exception as e:
                raise Exception("error/HUB/dataset_stats/init") from e

        self.hub_dir = Path(f'{data["path"]}-hub')  # sibling '<dataset>-hub' output directory
        self.im_dir = self.hub_dir / "images"
        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path):
        """Unzip data.zip; returns (zipped, data_dir, yaml_path). Non-zip paths pass straight through."""
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
        )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f):
        """Saves a compressed image for HUB previews."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save=False, verbose=False):
        """Return dataset JSON for Ultralytics HUB, optionally saving it to '<hub_dir>/stats.json'."""

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task == "segment":
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                # Pose rows are bbox coordinates followed by the flattened keypoints
                n = labels["keypoints"].shape[0]
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1)
            else:
                raise ValueError("Undefined dataset task.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine
            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder

                dataset = ImageFolder(self.data[split])

                # Per-class image counts via the (path, class_index) pairs in dataset.imgs
                x = np.zeros(len(dataset.classes)).astype(int)
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from yolov8_model.ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                # Per-image class histograms: one row of class counts per image
                x = np.array(
                    [
                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                    ]
                )  # shape(128x80)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        """Compress images for Ultralytics HUB."""
        from yolov8_model.ultralytics.data import YOLODataset  # ClassificationDataset

        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            # Compress every image in parallel; iterating the imap exhausts the pool work
            with ThreadPool(NUM_THREADS) as pool:
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir
575
+
576
+
577
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
    """
    Compress one image so its longest side is at most `max_dim` pixels, preserving the aspect ratio.

    PIL is tried first; any PIL failure falls back to OpenCV. Images already within `max_dim` are only
    re-encoded, not resized.

    Args:
        f (str): The path to the input image file.
        f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
        quality (int, optional): The image compression quality as a percentage. Default is 50%.

    Example:
        ```python
        from pathlib import Path
        from ultralytics.data.utils import compress_one_image

        for f in Path('path/to/dataset').rglob('*.jpg'):
            compress_one_image(f)
        ```
    """
    try:  # use PIL
        im = Image.open(f)
        scale = max_dim / max(im.height, im.width)
        if scale < 1.0:  # only shrink oversized images
            im = im.resize((int(im.width * scale), int(im.height * scale)))
        im.save(f_new or f, "JPEG", quality=quality, optimize=True)
    except Exception as e:  # use OpenCV
        LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
        im = cv2.imread(f)
        im_height, im_width = im.shape[:2]
        scale = max_dim / max(im_height, im_width)
        if scale < 1.0:
            im = cv2.resize(im, (int(im_width * scale), int(im_height * scale)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), im)
613
+
614
+
615
def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
    """
    Randomly assign every image under `path` to a train/val/test split and record the assignments in
    autosplit_*.txt files written next to the images directory.

    Args:
        path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
        weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
        annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.

    Example:
        ```python
        from ultralytics.data.utils import autosplit

        autosplit()
        ```
    """
    path = Path(path)  # images dir
    images = sorted(p for p in path.rglob("*.*") if p.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(images)
    random.seed(0)  # fixed seed so the split is reproducible across runs
    split_ids = random.choices([0, 1, 2], weights=weights, k=n)  # one split index per image

    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for name in txt:
        out_file = path.parent / name
        if out_file.exists():
            out_file.unlink()  # start each split file from scratch

    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for split_id, img in TQDM(zip(split_ids, images), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[split_id], "a") as f:
                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file
yolov8_model/ultralytics/engine/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
yolov8_model/ultralytics/engine/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (172 Bytes). View file
 
yolov8_model/ultralytics/engine/__pycache__/exporter.cpython-310.pyc ADDED
Binary file (38.5 kB). View file
 
yolov8_model/ultralytics/engine/__pycache__/model.cpython-310.pyc ADDED
Binary file (35.1 kB). View file
 
yolov8_model/ultralytics/engine/__pycache__/predictor.cpython-310.pyc ADDED
Binary file (14.7 kB). View file
 
yolov8_model/ultralytics/engine/__pycache__/results.cpython-310.pyc ADDED
Binary file (27.3 kB). View file
 
yolov8_model/ultralytics/engine/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (26.3 kB). View file
 
yolov8_model/ultralytics/engine/__pycache__/validator.cpython-310.pyc ADDED
Binary file (13.2 kB). View file
 
yolov8_model/ultralytics/engine/exporter.py ADDED
@@ -0,0 +1,1099 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
4
+
5
+ Format | `format=argument` | Model
6
+ --- | --- | ---
7
+ PyTorch | - | yolov8n.pt
8
+ TorchScript | `torchscript` | yolov8n.torchscript
9
+ ONNX | `onnx` | yolov8n.onnx
10
+ OpenVINO | `openvino` | yolov8n_openvino_model/
11
+ TensorRT | `engine` | yolov8n.engine
12
+ CoreML | `coreml` | yolov8n.mlpackage
13
+ TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
14
+ TensorFlow GraphDef | `pb` | yolov8n.pb
15
+ TensorFlow Lite | `tflite` | yolov8n.tflite
16
+ TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
17
+ TensorFlow.js | `tfjs` | yolov8n_web_model/
18
+ PaddlePaddle | `paddle` | yolov8n_paddle_model/
19
+ ncnn | `ncnn` | yolov8n_ncnn_model/
20
+
21
+ Requirements:
22
+ $ pip install "ultralytics[export]"
23
+
24
+ Python:
25
+ from ultralytics import YOLO
26
+ model = YOLO('yolov8n.pt')
27
+ results = model.export(format='onnx')
28
+
29
+ CLI:
30
+ $ yolo mode=export model=yolov8n.pt format=onnx
31
+
32
+ Inference:
33
+ $ yolo predict model=yolov8n.pt # PyTorch
34
+ yolov8n.torchscript # TorchScript
35
+ yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
36
+ yolov8n_openvino_model # OpenVINO
37
+ yolov8n.engine # TensorRT
38
+ yolov8n.mlpackage # CoreML (macOS-only)
39
+ yolov8n_saved_model # TensorFlow SavedModel
40
+ yolov8n.pb # TensorFlow GraphDef
41
+ yolov8n.tflite # TensorFlow Lite
42
+ yolov8n_edgetpu.tflite # TensorFlow Edge TPU
43
+ yolov8n_paddle_model # PaddlePaddle
44
+
45
+ TensorFlow.js:
46
+ $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
47
+ $ npm install
48
+ $ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
49
+ $ npm start
50
+ """
51
+ import json
52
+ import os
53
+ import shutil
54
+ import subprocess
55
+ import time
56
+ import warnings
57
+ from copy import deepcopy
58
+ from datetime import datetime
59
+ from pathlib import Path
60
+
61
+ import numpy as np
62
+ import torch
63
+
64
+ from yolov8_model.ultralytics.cfg import get_cfg
65
+ from yolov8_model.ultralytics.data.dataset import YOLODataset
66
+ from yolov8_model.ultralytics.data.utils import check_det_dataset
67
+ from yolov8_model.ultralytics.nn.autobackend import check_class_names, default_class_names
68
+ from yolov8_model.ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
69
+ from yolov8_model.ultralytics.nn.tasks import DetectionModel, SegmentationModel
70
+ from yolov8_model.ultralytics.utils import (
71
+ ARM64,
72
+ DEFAULT_CFG,
73
+ LINUX,
74
+ LOGGER,
75
+ MACOS,
76
+ ROOT,
77
+ WINDOWS,
78
+ __version__,
79
+ callbacks,
80
+ colorstr,
81
+ get_default_args,
82
+ yaml_save,
83
+ )
84
+ from yolov8_model.ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
85
+ from yolov8_model.ultralytics.utils.downloads import attempt_download_asset, get_github_assets
86
+ from yolov8_model.ultralytics.utils.files import file_size, spaces_in_path
87
+ from yolov8_model.ultralytics.utils.ops import Profile
88
+ from yolov8_model.ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
89
+
90
+
91
+ def export_formats():
92
+ """YOLOv8 export formats."""
93
+ import pandas
94
+
95
+ x = [
96
+ ["PyTorch", "-", ".pt", True, True],
97
+ ["TorchScript", "torchscript", ".torchscript", True, True],
98
+ ["ONNX", "onnx", ".onnx", True, True],
99
+ ["OpenVINO", "openvino", "_openvino_model", True, False],
100
+ ["TensorRT", "engine", ".engine", False, True],
101
+ ["CoreML", "coreml", ".mlpackage", True, False],
102
+ ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
103
+ ["TensorFlow GraphDef", "pb", ".pb", True, True],
104
+ ["TensorFlow Lite", "tflite", ".tflite", True, False],
105
+ ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
106
+ ["TensorFlow.js", "tfjs", "_web_model", True, False],
107
+ ["PaddlePaddle", "paddle", "_paddle_model", True, True],
108
+ ["ncnn", "ncnn", "_ncnn_model", True, True],
109
+ ]
110
+ return pandas.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])
111
+
112
+
113
+ def gd_outputs(gd):
114
+ """TensorFlow GraphDef model output node names."""
115
+ name_list, input_list = [], []
116
+ for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
117
+ name_list.append(node.name)
118
+ input_list.extend(node.input)
119
+ return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
120
+
121
+
122
+ def try_export(inner_func):
123
+ """YOLOv8 export decorator, i..e @try_export."""
124
+ inner_args = get_default_args(inner_func)
125
+
126
+ def outer_func(*args, **kwargs):
127
+ """Export a model."""
128
+ prefix = inner_args["prefix"]
129
+ try:
130
+ with Profile() as dt:
131
+ f, model = inner_func(*args, **kwargs)
132
+ LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
133
+ return f, model
134
+ except Exception as e:
135
+ LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
136
+ raise e
137
+
138
+ return outer_func
139
+
140
+
141
+ class Exporter:
142
+ """
143
+ A class for exporting a model.
144
+
145
+ Attributes:
146
+ args (SimpleNamespace): Configuration for the exporter.
147
+ callbacks (list, optional): List of callback functions. Defaults to None.
148
+ """
149
+
150
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
151
+ """
152
+ Initializes the Exporter class.
153
+
154
+ Args:
155
+ cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
156
+ overrides (dict, optional): Configuration overrides. Defaults to None.
157
+ _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
158
+ """
159
+ self.args = get_cfg(cfg, overrides)
160
+ if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
161
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
162
+
163
+ self.callbacks = _callbacks or callbacks.get_default_callbacks()
164
+ callbacks.add_integration_callbacks(self)
165
+
166
+ @smart_inference_mode()
167
+ def __call__(self, model=None):
168
+ """Returns list of exported files/dirs after running callbacks."""
169
+ self.run_callbacks("on_export_start")
170
+ t = time.time()
171
+ fmt = self.args.format.lower() # to lowercase
172
+ if fmt in ("tensorrt", "trt"): # 'engine' aliases
173
+ fmt = "engine"
174
+ if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases
175
+ fmt = "coreml"
176
+ fmts = tuple(export_formats()["Argument"][1:]) # available export formats
177
+ flags = [x == fmt for x in fmts]
178
+ if sum(flags) != 1:
179
+ raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
180
+ jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
181
+
182
+ # Device
183
+ if fmt == "engine" and self.args.device is None:
184
+ LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
185
+ self.args.device = "0"
186
+ self.device = select_device("cpu" if self.args.device is None else self.args.device)
187
+
188
+ # Checks
189
+ if not hasattr(model, "names"):
190
+ model.names = default_class_names()
191
+ model.names = check_class_names(model.names)
192
+ if self.args.half and onnx and self.device.type == "cpu":
193
+ LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0")
194
+ self.args.half = False
195
+ assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one."
196
+ self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
197
+ if self.args.optimize:
198
+ assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
199
+ assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
200
+ if edgetpu and not LINUX:
201
+ raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
202
+
203
+ # Input
204
+ im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
205
+ file = Path(
206
+ getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
207
+ )
208
+ if file.suffix in {".yaml", ".yml"}:
209
+ file = Path(file.name)
210
+
211
+ # Update model
212
+ model = deepcopy(model).to(self.device)
213
+ for p in model.parameters():
214
+ p.requires_grad = False
215
+ model.eval()
216
+ model.float()
217
+ model = model.fuse()
218
+ for m in model.modules():
219
+ if isinstance(m, (Detect, RTDETRDecoder)): # Segment and Pose use Detect base class
220
+ m.dynamic = self.args.dynamic
221
+ m.export = True
222
+ m.format = self.args.format
223
+ elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
224
+ # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
225
+ m.forward = m.forward_split
226
+
227
+ y = None
228
+ for _ in range(2):
229
+ y = model(im) # dry runs
230
+ if self.args.half and onnx and self.device.type != "cpu":
231
+ im, model = im.half(), model.half() # to FP16
232
+
233
+ # Filter warnings
234
+ warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning
235
+ warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning
236
+ warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
237
+
238
+ # Assign
239
+ self.im = im
240
+ self.model = model
241
+ self.file = file
242
+ self.output_shape = (
243
+ tuple(y.shape)
244
+ if isinstance(y, torch.Tensor)
245
+ else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
246
+ )
247
+ self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
248
+ data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
249
+ description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}'
250
+ self.metadata = {
251
+ "description": description,
252
+ "author": "Ultralytics",
253
+ "license": "AGPL-3.0 https://ultralytics.com/license",
254
+ "date": datetime.now().isoformat(),
255
+ "version": __version__,
256
+ "stride": int(max(model.stride)),
257
+ "task": model.task,
258
+ "batch": self.args.batch,
259
+ "imgsz": self.imgsz,
260
+ "names": model.names,
261
+ } # model metadata
262
+ if model.task == "pose":
263
+ self.metadata["kpt_shape"] = model.model[-1].kpt_shape
264
+
265
+ LOGGER.info(
266
+ f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
267
+ f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)'
268
+ )
269
+
270
+ # Exports
271
+ f = [""] * len(fmts) # exported filenames
272
+ if jit or ncnn: # TorchScript
273
+ f[0], _ = self.export_torchscript()
274
+ if engine: # TensorRT required before ONNX
275
+ f[1], _ = self.export_engine()
276
+ if onnx or xml: # OpenVINO requires ONNX
277
+ f[2], _ = self.export_onnx()
278
+ if xml: # OpenVINO
279
+ f[3], _ = self.export_openvino()
280
+ if coreml: # CoreML
281
+ f[4], _ = self.export_coreml()
282
+ if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
283
+ self.args.int8 |= edgetpu
284
+ f[5], keras_model = self.export_saved_model()
285
+ if pb or tfjs: # pb prerequisite to tfjs
286
+ f[6], _ = self.export_pb(keras_model=keras_model)
287
+ if tflite:
288
+ f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
289
+ if edgetpu:
290
+ f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
291
+ if tfjs:
292
+ f[9], _ = self.export_tfjs()
293
+ if paddle: # PaddlePaddle
294
+ f[10], _ = self.export_paddle()
295
+ if ncnn: # ncnn
296
+ f[11], _ = self.export_ncnn()
297
+
298
+ # Finish
299
+ f = [str(x) for x in f if x] # filter out '' and None
300
+ if any(f):
301
+ f = str(Path(f[-1]))
302
+ square = self.imgsz[0] == self.imgsz[1]
303
+ s = (
304
+ ""
305
+ if square
306
+ else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
307
+ f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
308
+ )
309
+ imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
310
+ predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
311
+ q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization
312
+ LOGGER.info(
313
+ f'\nExport complete ({time.time() - t:.1f}s)'
314
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
315
+ f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
316
+ f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
317
+ f'\nVisualize: https://netron.app'
318
+ )
319
+
320
+ self.run_callbacks("on_export_end")
321
+ return f # return list of exported files/dirs
322
+
323
+ @try_export
324
+ def export_torchscript(self, prefix=colorstr("TorchScript:")):
325
+ """YOLOv8 TorchScript model export."""
326
+ LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
327
+ f = self.file.with_suffix(".torchscript")
328
+
329
+ ts = torch.jit.trace(self.model, self.im, strict=False)
330
+ extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
331
+ if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
332
+ LOGGER.info(f"{prefix} optimizing for mobile...")
333
+ from torch.utils.mobile_optimizer import optimize_for_mobile
334
+
335
+ optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
336
+ else:
337
+ ts.save(str(f), _extra_files=extra_files)
338
+ return f, None
339
+
340
+ @try_export
341
+ def export_onnx(self, prefix=colorstr("ONNX:")):
342
+ """YOLOv8 ONNX export."""
343
+ requirements = ["onnx>=1.12.0"]
344
+ if self.args.simplify:
345
+ requirements += ["onnxsim>=0.4.33", "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"]
346
+ check_requirements(requirements)
347
+ import onnx # noqa
348
+
349
+ opset_version = self.args.opset or get_latest_opset()
350
+ LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
351
+ f = str(self.file.with_suffix(".onnx"))
352
+
353
+ output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
354
+ dynamic = self.args.dynamic
355
+ if dynamic:
356
+ dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640)
357
+ if isinstance(self.model, SegmentationModel):
358
+ dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 116, 8400)
359
+ dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160)
360
+ elif isinstance(self.model, DetectionModel):
361
+ dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400)
362
+
363
+ torch.onnx.export(
364
+ self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu
365
+ self.im.cpu() if dynamic else self.im,
366
+ f,
367
+ verbose=False,
368
+ opset_version=opset_version,
369
+ do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
370
+ input_names=["images"],
371
+ output_names=output_names,
372
+ dynamic_axes=dynamic or None,
373
+ )
374
+
375
+ # Checks
376
+ model_onnx = onnx.load(f) # load onnx model
377
+ # onnx.checker.check_model(model_onnx) # check onnx model
378
+
379
+ # Simplify
380
+ if self.args.simplify:
381
+ try:
382
+ import onnxsim
383
+
384
+ LOGGER.info(f"{prefix} simplifying with onnxsim {onnxsim.__version__}...")
385
+ # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
386
+ model_onnx, check = onnxsim.simplify(model_onnx)
387
+ assert check, "Simplified ONNX model could not be validated"
388
+ except Exception as e:
389
+ LOGGER.info(f"{prefix} simplifier failure: {e}")
390
+
391
+ # Metadata
392
+ for k, v in self.metadata.items():
393
+ meta = model_onnx.metadata_props.add()
394
+ meta.key, meta.value = k, str(v)
395
+
396
+ onnx.save(model_onnx, f)
397
+ return f, model_onnx
398
+
399
+ @try_export
400
+ def export_openvino(self, prefix=colorstr("OpenVINO:")):
401
+ """YOLOv8 OpenVINO export."""
402
+ check_requirements("openvino-dev>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/
403
+ import openvino.runtime as ov # noqa
404
+ from openvino.tools import mo # noqa
405
+
406
+ LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
407
+ f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
408
+ fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
409
+ f_onnx = self.file.with_suffix(".onnx")
410
+ f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
411
+ fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
412
+
413
+ def serialize(ov_model, file):
414
+ """Set RT info, serialize and save metadata YAML."""
415
+ ov_model.set_rt_info("YOLOv8", ["model_info", "model_type"])
416
+ ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
417
+ ov_model.set_rt_info(114, ["model_info", "pad_value"])
418
+ ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
419
+ ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
420
+ ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
421
+ if self.model.task != "classify":
422
+ ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
423
+
424
+ ov.serialize(ov_model, file) # save
425
+ yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml
426
+
427
+ ov_model = mo.convert_model(
428
+ f_onnx, model_name=self.pretty_name, framework="onnx", compress_to_fp16=self.args.half
429
+ ) # export
430
+
431
+ if self.args.int8:
432
+ if not self.args.data:
433
+ self.args.data = DEFAULT_CFG.data or "coco128.yaml"
434
+ LOGGER.warning(
435
+ f"{prefix} WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. "
436
+ f"Using default 'data={self.args.data}'."
437
+ )
438
+ check_requirements("nncf>=2.5.0")
439
+ import nncf
440
+
441
+ def transform_fn(data_item):
442
+ """Quantization transform function."""
443
+ assert (
444
+ data_item["img"].dtype == torch.uint8
445
+ ), "Input image must be uint8 for the quantization preprocessing"
446
+ im = data_item["img"].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
447
+ return np.expand_dims(im, 0) if im.ndim == 3 else im
448
+
449
+ # Generate calibration data for integer quantization
450
+ LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
451
+ data = check_det_dataset(self.args.data)
452
+ dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
453
+ n = len(dataset)
454
+ if n < 300:
455
+ LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
456
+ quantization_dataset = nncf.Dataset(dataset, transform_fn)
457
+ ignored_scope = nncf.IgnoredScope(types=["Multiply", "Subtract", "Sigmoid"]) # ignore operation
458
+ quantized_ov_model = nncf.quantize(
459
+ ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope
460
+ )
461
+ serialize(quantized_ov_model, fq_ov)
462
+ return fq, None
463
+
464
+ serialize(ov_model, f_ov)
465
+ return f, None
466
+
467
+ @try_export
468
+ def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
469
+ """YOLOv8 Paddle export."""
470
+ check_requirements(("paddlepaddle", "x2paddle"))
471
+ import x2paddle # noqa
472
+ from x2paddle.convert import pytorch2paddle # noqa
473
+
474
+ LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
475
+ f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")
476
+
477
+ pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im]) # export
478
+ yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
479
+ return f, None
480
+
481
+ @try_export
482
+ def export_ncnn(self, prefix=colorstr("ncnn:")):
483
+ """
484
+ YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx.
485
+ """
486
+ check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires ncnn
487
+ import ncnn # noqa
488
+
489
+ LOGGER.info(f"\n{prefix} starting export with ncnn {ncnn.__version__}...")
490
+ f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
491
+ f_ts = self.file.with_suffix(".torchscript")
492
+
493
+ name = Path("pnnx.exe" if WINDOWS else "pnnx") # PNNX filename
494
+ pnnx = name if name.is_file() else ROOT / name
495
+ if not pnnx.is_file():
496
+ LOGGER.warning(
497
+ f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from "
498
+ "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
499
+ f"or in {ROOT}. See PNNX repo for full installation instructions."
500
+ )
501
+ system = ["macos"] if MACOS else ["windows"] if WINDOWS else ["ubuntu", "linux"] # operating system
502
+ try:
503
+ _, assets = get_github_assets(repo="pnnx/pnnx", retry=True)
504
+ url = [x for x in assets if any(s in x for s in system)][0]
505
+ except Exception as e:
506
+ url = f"https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip"
507
+ LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}")
508
+ asset = attempt_download_asset(url, repo="pnnx/pnnx", release="latest")
509
+ if check_is_path_safe(Path.cwd(), asset): # avoid path traversal security vulnerability
510
+ unzip_dir = Path(asset).with_suffix("")
511
+ (unzip_dir / name).rename(pnnx) # move binary to ROOT
512
+ shutil.rmtree(unzip_dir) # delete unzip dir
513
+ Path(asset).unlink() # delete zip
514
+ pnnx.chmod(0o777) # set read, write, and execute permissions for everyone
515
+
516
+ ncnn_args = [
517
+ f'ncnnparam={f / "model.ncnn.param"}',
518
+ f'ncnnbin={f / "model.ncnn.bin"}',
519
+ f'ncnnpy={f / "model_ncnn.py"}',
520
+ ]
521
+
522
+ pnnx_args = [
523
+ f'pnnxparam={f / "model.pnnx.param"}',
524
+ f'pnnxbin={f / "model.pnnx.bin"}',
525
+ f'pnnxpy={f / "model_pnnx.py"}',
526
+ f'pnnxonnx={f / "model.pnnx.onnx"}',
527
+ ]
528
+
529
+ cmd = [
530
+ str(pnnx),
531
+ str(f_ts),
532
+ *ncnn_args,
533
+ *pnnx_args,
534
+ f"fp16={int(self.args.half)}",
535
+ f"device={self.device.type}",
536
+ f'inputshape="{[self.args.batch, 3, *self.imgsz]}"',
537
+ ]
538
+ f.mkdir(exist_ok=True) # make ncnn_model directory
539
+ LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
540
+ subprocess.run(cmd, check=True)
541
+
542
+ # Remove debug files
543
+ pnnx_files = [x.split("=")[-1] for x in pnnx_args]
544
+ for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
545
+ Path(f_debug).unlink(missing_ok=True)
546
+
547
+ yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml
548
+ return str(f), None
549
+
550
+ @try_export
551
+ def export_coreml(self, prefix=colorstr("CoreML:")):
552
+ """YOLOv8 CoreML export."""
553
+ mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested
554
+ check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=7.0")
555
+ import coremltools as ct # noqa
556
+
557
+ LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
558
+ assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux."
559
+ f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
560
+ if f.is_dir():
561
+ shutil.rmtree(f)
562
+
563
+ bias = [0.0, 0.0, 0.0]
564
+ scale = 1 / 255
565
+ classifier_config = None
566
+ if self.model.task == "classify":
567
+ classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
568
+ model = self.model
569
+ elif self.model.task == "detect":
570
+ model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
571
+ else:
572
+ if self.args.nms:
573
+ LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolov8n.pt'.")
574
+ # TODO CoreML Segment and Pose model pipelining
575
+ model = self.model
576
+
577
+ ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
578
+ ct_model = ct.convert(
579
+ ts,
580
+ inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
581
+ classifier_config=classifier_config,
582
+ convert_to="neuralnetwork" if mlmodel else "mlprogram",
583
+ )
584
+ bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
585
+ if bits < 32:
586
+ if "kmeans" in mode:
587
+ check_requirements("scikit-learn") # scikit-learn package required for k-means quantization
588
+ if mlmodel:
589
+ ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
590
+ elif bits == 8: # mlprogram already quantized to FP16
591
+ import coremltools.optimize.coreml as cto
592
+
593
+ op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
594
+ config = cto.OptimizationConfig(global_config=op_config)
595
+ ct_model = cto.palettize_weights(ct_model, config=config)
596
+ if self.args.nms and self.model.task == "detect":
597
+ if mlmodel:
598
+ import platform
599
+
600
+ # coremltools<=6.2 NMS export requires Python<3.11
601
+ check_version(platform.python_version(), "<3.11", name="Python ", hard=True)
602
+ weights_dir = None
603
+ else:
604
+ ct_model.save(str(f)) # save otherwise weights_dir does not exist
605
+ weights_dir = str(f / "Data/com.apple.CoreML/weights")
606
+ ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)
607
+
608
+ m = self.metadata # metadata dict
609
+ ct_model.short_description = m.pop("description")
610
+ ct_model.author = m.pop("author")
611
+ ct_model.license = m.pop("license")
612
+ ct_model.version = m.pop("version")
613
+ ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
614
+ try:
615
+ ct_model.save(str(f)) # save *.mlpackage
616
+ except Exception as e:
617
+ LOGGER.warning(
618
+ f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
619
+ f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
620
+ )
621
+ f = f.with_suffix(".mlmodel")
622
+ ct_model.save(str(f))
623
+ return f, ct_model
624
+
625
+ @try_export
626
+ def export_engine(self, prefix=colorstr("TensorRT:")):
627
+ """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
628
+ assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
629
+ f_onnx, _ = self.export_onnx() # run before trt import https://github.com/ultralytics/ultralytics/issues/7016
630
+
631
+ try:
632
+ import tensorrt as trt # noqa
633
+ except ImportError:
634
+ if LINUX:
635
+ check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
636
+ import tensorrt as trt # noqa
637
+
638
+ check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0
639
+
640
+ self.args.simplify = True
641
+
642
+ LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
643
+ assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
644
+ f = self.file.with_suffix(".engine") # TensorRT engine file
645
+ logger = trt.Logger(trt.Logger.INFO)
646
+ if self.args.verbose:
647
+ logger.min_severity = trt.Logger.Severity.VERBOSE
648
+
649
+ builder = trt.Builder(logger)
650
+ config = builder.create_builder_config()
651
+ config.max_workspace_size = self.args.workspace * 1 << 30
652
+ # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
653
+
654
+ flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
655
+ network = builder.create_network(flag)
656
+ parser = trt.OnnxParser(network, logger)
657
+ if not parser.parse_from_file(f_onnx):
658
+ raise RuntimeError(f"failed to load ONNX file: {f_onnx}")
659
+
660
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
661
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
662
+ for inp in inputs:
663
+ LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
664
+ for out in outputs:
665
+ LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
666
+
667
+ if self.args.dynamic:
668
+ shape = self.im.shape
669
+ if shape[0] <= 1:
670
+ LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
671
+ profile = builder.create_optimization_profile()
672
+ for inp in inputs:
673
+ profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
674
+ config.add_optimization_profile(profile)
675
+
676
+ LOGGER.info(
677
+ f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
678
+ )
679
+ if builder.platform_has_fast_fp16 and self.args.half:
680
+ config.set_flag(trt.BuilderFlag.FP16)
681
+
682
+ del self.model
683
+ torch.cuda.empty_cache()
684
+
685
+ # Write file
686
+ with builder.build_engine(network, config) as engine, open(f, "wb") as t:
687
+ # Metadata
688
+ meta = json.dumps(self.metadata)
689
+ t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
690
+ t.write(meta.encode())
691
+ # Model
692
+ t.write(engine.serialize())
693
+
694
+ return f, None
695
+
696
+ @try_export
697
+ def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
698
+ """YOLOv8 TensorFlow SavedModel export."""
699
+ cuda = torch.cuda.is_available()
700
+ try:
701
+ import tensorflow as tf # noqa
702
+ except ImportError:
703
+ check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
704
+ import tensorflow as tf # noqa
705
+ check_requirements(
706
+ (
707
+ "onnx",
708
+ "onnx2tf>=1.15.4,<=1.17.5",
709
+ "sng4onnx>=1.0.1",
710
+ "onnxsim>=0.4.33",
711
+ "onnx_graphsurgeon>=0.3.26",
712
+ "tflite_support",
713
+ "onnxruntime-gpu" if cuda else "onnxruntime",
714
+ ),
715
+ cmds="--extra-index-url https://pypi.ngc.nvidia.com",
716
+ ) # onnx_graphsurgeon only on NVIDIA
717
+
718
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
719
+ check_version(
720
+ tf.__version__,
721
+ "<=2.13.1",
722
+ name="tensorflow",
723
+ verbose=True,
724
+ msg="https://github.com/ultralytics/ultralytics/issues/5161",
725
+ )
726
+ f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
727
+ if f.is_dir():
728
+ import shutil
729
+
730
+ shutil.rmtree(f) # delete output folder
731
+
732
+ # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
733
+ onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
734
+ if not onnx2tf_file.exists():
735
+ attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
736
+
737
+ # Export to ONNX
738
+ self.args.simplify = True
739
+ f_onnx, _ = self.export_onnx()
740
+
741
+ # Export to TF
742
+ tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
743
+ if self.args.int8:
744
+ verbosity = "--verbosity info"
745
+ if self.args.data:
746
+ # Generate calibration data for integer quantization
747
+ LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
748
+ data = check_det_dataset(self.args.data)
749
+ dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
750
+ images = []
751
+ for i, batch in enumerate(dataset):
752
+ if i >= 100: # maximum number of calibration images
753
+ break
754
+ im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC
755
+ images.append(im)
756
+ f.mkdir()
757
+ images = torch.cat(images, 0).float()
758
+ # mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
759
+ # std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
760
+ np.save(str(tmp_file), images.numpy()) # BHWC
761
+ int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
762
+ else:
763
+ int8 = "-oiqt -qt per-tensor"
764
+ else:
765
+ verbosity = "--non_verbose"
766
+ int8 = ""
767
+
768
+ cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo {verbosity} {int8}'.strip()
769
+ LOGGER.info(f"{prefix} running '{cmd}'")
770
+ subprocess.run(cmd, shell=True)
771
+ yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml
772
+
773
+ # Remove/rename TFLite models
774
+ if self.args.int8:
775
+ tmp_file.unlink(missing_ok=True)
776
+ for file in f.rglob("*_dynamic_range_quant.tflite"):
777
+ file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
778
+ for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
779
+ file.unlink() # delete extra fp16 activation TFLite files
780
+
781
+ # Add TFLite metadata
782
+ for file in f.rglob("*.tflite"):
783
+ f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file)
784
+
785
+ return str(f), tf.saved_model.load(f, tags=None, options=None) # load saved_model as Keras model
786
+
787
+ @try_export
788
+ def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
789
+ """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
790
+ import tensorflow as tf # noqa
791
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
792
+
793
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
794
+ f = self.file.with_suffix(".pb")
795
+
796
+ m = tf.function(lambda x: keras_model(x)) # full model
797
+ m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
798
+ frozen_func = convert_variables_to_constants_v2(m)
799
+ frozen_func.graph.as_graph_def()
800
+ tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
801
+ return f, None
802
+
803
+ @try_export
804
+ def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
805
+ """YOLOv8 TensorFlow Lite export."""
806
+ import tensorflow as tf # noqa
807
+
808
+ LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
809
+ saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
810
+ if self.args.int8:
811
+ f = saved_model / f"{self.file.stem}_int8.tflite" # fp32 in/out
812
+ elif self.args.half:
813
+ f = saved_model / f"{self.file.stem}_float16.tflite" # fp32 in/out
814
+ else:
815
+ f = saved_model / f"{self.file.stem}_float32.tflite"
816
+ return str(f), None
817
+
818
    @try_export
    def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
        """
        YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/.

        Compiles an existing *.tflite model for the Coral Edge TPU using the 'edgetpu_compiler'
        CLI, installing the compiler via Coral's apt repository when missing (Linux only).

        Args:
            tflite_model (str): Path to the source *.tflite model to compile.

        Returns:
            (str, None): Path to the compiled '*_edgetpu.tflite' model, and None.
        """
        LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")

        cmd = "edgetpu_compiler --version"
        help_url = "https://coral.ai/docs/edgetpu/compiler/"
        assert LINUX, f"export only supported on Linux. See {help_url}"
        # Probe for the compiler; a non-zero return code means it is not installed.
        if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
            LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
            sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0  # sudo installed on system
            # Register Coral's apt repo and install the compiler, dropping 'sudo ' when it is
            # unavailable (e.g. inside containers running as root).
            for c in (
                "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
                'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
                "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
                "sudo apt-get update",
                "sudo apt-get install edgetpu-compiler",
            ):
                subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True)
        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

        LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
        f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model

        # Compile next to the source model; the compiler writes its own output file.
        cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
        LOGGER.info(f"{prefix} running '{cmd}'")
        subprocess.run(cmd, shell=True)
        self._add_tflite_metadata(f)
        return f, None
847
+
848
    @try_export
    def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
        """
        YOLOv8 TensorFlow.js export.

        Converts the previously exported frozen graph (*.pb) into a TF.js web-model directory
        via the 'tensorflowjs_converter' CLI, applying float16/uint8 quantization per args.

        Returns:
            (str, None): Path to the '*_web_model' output directory, and None.
        """
        # JAX bug requiring install constraints in https://github.com/google/jax/issues/18978
        check_requirements(["jax<=0.4.21", "jaxlib<=0.4.21", "tensorflowjs"])
        import tensorflow as tf
        import tensorflowjs as tfjs  # noqa

        LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
        f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
        f_pb = str(self.file.with_suffix(".pb"))  # *.pb path

        # Parse the frozen graph to discover its output node names for the converter.
        gd = tf.Graph().as_graph_def()  # TF GraphDef
        with open(f_pb, "rb") as file:
            gd.ParseFromString(file.read())
        outputs = ",".join(gd_outputs(gd))
        LOGGER.info(f"\n{prefix} output node names: {outputs}")

        quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
        with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
            cmd = f'tensorflowjs_converter --input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
            LOGGER.info(f"{prefix} running '{cmd}'")
            subprocess.run(cmd, shell=True)

        if " " in f:
            LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")

        # f_json = Path(f) / 'model.json'  # *.json path
        # with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
        #     subst = re.sub(
        #         r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
        #         r'"Identity.?.?": {"name": "Identity.?.?"}}}',
        #         r'{"outputs": {"Identity": {"name": "Identity"}, '
        #         r'"Identity_1": {"name": "Identity_1"}, '
        #         r'"Identity_2": {"name": "Identity_2"}, '
        #         r'"Identity_3": {"name": "Identity_3"}}}',
        #         f_json.read_text(),
        #     )
        #     j.write(subst)
        yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
        return f, None
891
+
892
    def _add_tflite_metadata(self, file):
        """
        Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata.

        Embeds the exporter's metadata dict (written to a temporary text file) together with
        input/output tensor descriptions into the TFLite flatbuffer in place.

        Args:
            file (str | Path): Path to the *.tflite model to annotate.
        """
        from tflite_support import flatbuffers  # noqa
        from tflite_support import metadata as _metadata  # noqa
        from tflite_support import metadata_schema_py_generated as _metadata_fb  # noqa

        # Create model info
        model_meta = _metadata_fb.ModelMetadataT()
        model_meta.name = self.metadata["description"]
        model_meta.version = self.metadata["version"]
        model_meta.author = self.metadata["author"]
        model_meta.license = self.metadata["license"]

        # Label file: the full metadata dict is attached as an associated text file
        tmp_file = Path(file).parent / "temp_meta.txt"
        with open(tmp_file, "w") as f:
            f.write(str(self.metadata))

        label_file = _metadata_fb.AssociatedFileT()
        label_file.name = tmp_file.name
        label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS

        # Create input info
        input_meta = _metadata_fb.TensorMetadataT()
        input_meta.name = "image"
        input_meta.description = "Input image to be detected."
        input_meta.content = _metadata_fb.ContentT()
        input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
        input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
        input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties

        # Create output info
        output1 = _metadata_fb.TensorMetadataT()
        output1.name = "output"
        output1.description = "Coordinates of detected objects, class labels, and confidence score"
        output1.associatedFiles = [label_file]
        if self.model.task == "segment":
            # Segmentation models carry a second output for mask prototypes
            output2 = _metadata_fb.TensorMetadataT()
            output2.name = "output"
            output2.description = "Mask protos"
            output2.associatedFiles = [label_file]

        # Create subgraph info
        subgraph = _metadata_fb.SubGraphMetadataT()
        subgraph.inputTensorMetadata = [input_meta]
        subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
        model_meta.subgraphMetadata = [subgraph]

        # Serialize the metadata flatbuffer
        b = flatbuffers.Builder(0)
        b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
        metadata_buf = b.Output()

        # Write metadata + associated file into the model, then remove the temp file
        populator = _metadata.MetadataPopulator.with_model_file(str(file))
        populator.load_metadata_buffer(metadata_buf)
        populator.load_associated_files([str(tmp_file)])
        populator.populate()
        tmp_file.unlink()
949
+
950
    def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
        """
        YOLOv8 CoreML pipeline.

        Wraps a converted CoreML detection model with a CoreML NonMaximumSuppression model and
        combines both into a single Pipeline so the exported model outputs NMS-filtered
        'confidence' and 'coordinates' features directly.

        Args:
            model: Converted CoreML model (coremltools MLModel) to wrap.
            weights_dir (str | None): Weights directory for ML-Program models.

        Returns:
            (ct.models.MLModel): The assembled pipeline model.
        """
        import coremltools as ct  # noqa

        LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
        _, _, h, w = list(self.im.shape)  # BCHW

        # Output shapes
        spec = model.get_spec()
        out0, out1 = iter(spec.description.output)
        if MACOS:
            from PIL import Image

            img = Image.new("RGB", (w, h))  # w=192, h=320
            out = model.predict({"image": img})
            out0_shape = out[out0.name].shape  # (3780, 80)
            out1_shape = out[out1.name].shape  # (3780, 4)
        else:  # linux and windows can not run model.predict(), get sizes from PyTorch model output y
            out0_shape = self.output_shape[2], self.output_shape[1] - 4  # (3780, 80)
            out1_shape = self.output_shape[2], 4  # (3780, 4)

        # Checks
        names = self.metadata["names"]
        nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
        _, nc = out0_shape  # number of anchors, number of classes
        # _, nc = out0.type.multiArrayType.shape
        assert len(names) == nc, f"{len(names)} names found for nc={nc}"  # check

        # Define output shapes (missing)
        out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
        out1.type.multiArrayType.shape[:] = out1_shape  # (3780, 4)
        # spec.neuralNetwork.preprocessing[0].featureName = '0'

        # Flexible input shapes
        # from coremltools.models.neural_network import flexible_shape_utils
        # s = []  # shapes
        # s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
        # s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384))  # (height, width)
        # flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
        # r = flexible_shape_utils.NeuralNetworkImageSizeRange()  # shape ranges
        # r.add_height_range((192, 640))
        # r.add_width_range((192, 640))
        # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)

        # Print
        # print(spec.description)

        # Model from spec
        model = ct.models.MLModel(spec, weights_dir=weights_dir)

        # 3. Create NMS protobuf
        nms_spec = ct.proto.Model_pb2.Model()
        nms_spec.specificationVersion = 5
        # NMS inputs mirror the decoder outputs; copy their descriptions verbatim.
        for i in range(2):
            decoder_output = model._spec.description.output[i].SerializeToString()
            nms_spec.description.input.add()
            nms_spec.description.input[i].ParseFromString(decoder_output)
            nms_spec.description.output.add()
            nms_spec.description.output[i].ParseFromString(decoder_output)

        nms_spec.description.output[0].name = "confidence"
        nms_spec.description.output[1].name = "coordinates"

        # Flexible first dimension (number of surviving boxes), fixed second dimension (nc / 4).
        output_sizes = [nc, 4]
        for i in range(2):
            ma_type = nms_spec.description.output[i].type.multiArrayType
            ma_type.shapeRange.sizeRanges.add()
            ma_type.shapeRange.sizeRanges[0].lowerBound = 0
            ma_type.shapeRange.sizeRanges[0].upperBound = -1
            ma_type.shapeRange.sizeRanges.add()
            ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
            ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
            del ma_type.shape[:]

        nms = nms_spec.nonMaximumSuppression
        nms.confidenceInputFeatureName = out0.name  # 1x507x80
        nms.coordinatesInputFeatureName = out1.name  # 1x507x4
        nms.confidenceOutputFeatureName = "confidence"
        nms.coordinatesOutputFeatureName = "coordinates"
        nms.iouThresholdInputFeatureName = "iouThreshold"
        nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
        nms.iouThreshold = 0.45
        nms.confidenceThreshold = 0.25
        nms.pickTop.perClass = True
        nms.stringClassLabels.vector.extend(names.values())
        nms_model = ct.models.MLModel(nms_spec)

        # 4. Pipeline models together
        pipeline = ct.models.pipeline.Pipeline(
            input_features=[
                ("image", ct.models.datatypes.Array(3, ny, nx)),
                ("iouThreshold", ct.models.datatypes.Double()),
                ("confidenceThreshold", ct.models.datatypes.Double()),
            ],
            output_features=["confidence", "coordinates"],
        )
        pipeline.add_model(model)
        pipeline.add_model(nms_model)

        # Correct datatypes
        pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
        pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
        pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())

        # Update metadata
        pipeline.spec.specificationVersion = 5
        pipeline.spec.description.metadata.userDefined.update(
            {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
        )

        # Save the model
        model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
        model.input_description["image"] = "Input image"
        model.input_description["iouThreshold"] = f"(optional) IOU threshold override (default: {nms.iouThreshold})"
        model.input_description[
            "confidenceThreshold"
        ] = f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
        model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
        model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
        LOGGER.info(f"{prefix} pipeline success")
        return model
1071
+
1072
+ def add_callback(self, event: str, callback):
1073
+ """Appends the given callback."""
1074
+ self.callbacks[event].append(callback)
1075
+
1076
+ def run_callbacks(self, event: str):
1077
+ """Execute all callbacks for a given event."""
1078
+ for callback in self.callbacks.get(event, []):
1079
+ callback(self)
1080
+
1081
+
1082
class IOSDetectModel(torch.nn.Module):
    """Wrap an Ultralytics YOLO model so its outputs match Apple iOS CoreML export conventions."""

    def __init__(self, model, im):
        """Store the wrapped model, class count, and a size-dependent box-normalization factor."""
        super().__init__()
        batch, channels, height, width = im.shape
        self.model = model
        self.nc = len(model.names)  # number of classes
        # Square inputs only need a single scalar divisor; rectangular inputs normalize
        # each of x, y, w, h by the matching image dimension.
        if width == height:
            self.normalize = 1.0 / width
        else:
            self.normalize = torch.tensor([1.0 / width, 1.0 / height, 1.0 / width, 1.0 / height])

    def forward(self, x):
        """Run inference and return (class scores, boxes normalized to image size)."""
        preds = self.model(x)[0].transpose(0, 1)
        xywh, cls = preds.split((4, self.nc), 1)
        return cls, xywh * self.normalize
yolov8_model/ultralytics/engine/model.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import torch
4
+ import inspect
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+ from yolov8_model.ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
10
+ from yolov8_model.ultralytics.hub.utils import HUB_WEB_ROOT
11
+ from yolov8_model.ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
12
+ from yolov8_model.ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
13
+
14
+
15
+ class Model(nn.Module):
16
+ """
17
+ A base class for implementing YOLO models, unifying APIs across different model types.
18
+
19
+ This class provides a common interface for various operations related to YOLO models, such as training,
20
+ validation, prediction, exporting, and benchmarking. It handles different types of models, including those
21
+ loaded from local files, Ultralytics HUB, or Triton Server. The class is designed to be flexible and
22
+ extendable for different tasks and model configurations.
23
+
24
+ Args:
25
+ model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file
26
+ path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
27
+ task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's
28
+ application domain, such as object detection, segmentation, etc. Defaults to None.
29
+ verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False.
30
+
31
+ Attributes:
32
+ callbacks (dict): A dictionary of callback functions for various events during model operations.
33
+ predictor (BasePredictor): The predictor object used for making predictions.
34
+ model (nn.Module): The underlying PyTorch model.
35
+ trainer (BaseTrainer): The trainer object used for training the model.
36
+ ckpt (dict): The checkpoint data if the model is loaded from a *.pt file.
37
+ cfg (str): The configuration of the model if loaded from a *.yaml file.
38
+ ckpt_path (str): The path to the checkpoint file.
39
+ overrides (dict): A dictionary of overrides for model configuration.
40
+ metrics (dict): The latest training/validation metrics.
41
+ session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
42
+ task (str): The type of task the model is intended for.
43
+ model_name (str): The name of the model.
44
+
45
+ Methods:
46
+ __call__: Alias for the predict method, enabling the model instance to be callable.
47
+ _new: Initializes a new model based on a configuration file.
48
+ _load: Loads a model from a checkpoint file.
49
+ _check_is_pytorch_model: Ensures that the model is a PyTorch model.
50
+ reset_weights: Resets the model's weights to their initial state.
51
+ load: Loads model weights from a specified file.
52
+ save: Saves the current state of the model to a file.
53
+ info: Logs or returns information about the model.
54
+ fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference.
55
+ predict: Performs object detection predictions.
56
+ track: Performs object tracking.
57
+ val: Validates the model on a dataset.
58
+ benchmark: Benchmarks the model on various export formats.
59
+ export: Exports the model to different formats.
60
+ train: Trains the model on a dataset.
61
+ tune: Performs hyperparameter tuning.
62
+ _apply: Applies a function to the model's tensors.
63
+ add_callback: Adds a callback function for an event.
64
+ clear_callback: Clears all callbacks for an event.
65
+ reset_callbacks: Resets all callbacks to their default functions.
66
+ _get_hub_session: Retrieves or creates an Ultralytics HUB session.
67
+ is_triton_model: Checks if a model is a Triton Server model.
68
+ is_hub_model: Checks if a model is an Ultralytics HUB model.
69
+ _reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
70
+ _smart_load: Loads the appropriate module based on the model task.
71
+ task_map: Provides a mapping from model tasks to corresponding classes.
72
+
73
+ Raises:
74
+ FileNotFoundError: If the specified model file does not exist or is inaccessible.
75
+ ValueError: If the model file or configuration is invalid or unsupported.
76
+ ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
77
+ TypeError: If the model is not a PyTorch model when required.
78
+ AttributeError: If required attributes or methods are not implemented or available.
79
+ NotImplementedError: If a specific model task or mode is not supported.
80
+ """
81
+
82
    def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None, verbose=False) -> None:
        """
        Initializes a new instance of the YOLO model class.

        This constructor sets up the model based on the provided model path or name. It handles various types of model
        sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
        important attributes of the model and prepares it for operations like training, prediction, or export.

        Args:
            model (Union[str, Path], optional): The path or model file to load or create. This can be a local
                file path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
            task (Any, optional): The task type associated with the YOLO model, specifying its application domain.
                Defaults to None.
            verbose (bool, optional): If True, enables verbose output during the model's initialization and subsequent
                operations. Defaults to False.

        Raises:
            FileNotFoundError: If the specified model file does not exist or is inaccessible.
            ValueError: If the model file or configuration is invalid or unsupported.
            ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
        """
        super().__init__()
        self.callbacks = callbacks.get_default_callbacks()
        self.predictor = None  # reuse predictor
        self.model = None  # model object
        self.trainer = None  # trainer object
        self.ckpt = None  # if loaded from *.pt
        self.cfg = None  # if loaded from *.yaml
        self.ckpt_path = None  # path of the loaded checkpoint, set by _load()
        self.overrides = {}  # overrides for trainer object
        self.metrics = None  # validation/training metrics
        self.session = None  # HUB session
        self.task = task  # task type
        self.model_name = model = str(model).strip()  # strip spaces

        # Check if Ultralytics HUB model from https://hub.ultralytics.com
        if self.is_hub_model(model):
            # Fetch model from HUB; subsequent loading continues with the downloaded weights file
            checks.check_requirements("hub-sdk>0.0.2")
            self.session = self._get_hub_session(model)
            model = self.session.model_file

        # Check if Triton Server model
        elif self.is_triton_model(model):
            # Triton models are served remotely: keep the URL as-is and return early,
            # skipping local file loading entirely.
            self.model = model
            self.task = task
            return

        # Load or create new YOLO model
        model = checks.check_model_file_from_stem(model)  # add suffix, i.e. yolov8n -> yolov8n.pt
        if Path(model).suffix in (".yaml", ".yml"):
            self._new(model, task=task)  # build a fresh model from a YAML config
        else:
            self._load(model, task=task)  # load weights from a checkpoint/exported file

        self.model_name = model
138
+
139
+ def __call__(self, source=None, stream=False, **kwargs):
140
+ """
141
+ An alias for the predict method, enabling the model instance to be callable.
142
+
143
+ This method simplifies the process of making predictions by allowing the model instance to be called directly
144
+ with the required arguments for prediction.
145
+
146
+ Args:
147
+ source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
148
+ Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to None.
149
+ stream (bool, optional): If True, treats the input source as a continuous stream for predictions.
150
+ Defaults to False.
151
+ **kwargs (dict): Additional keyword arguments for configuring the prediction process.
152
+
153
+ Returns:
154
+ (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.
155
+ """
156
+ return self.predict(source, stream, **kwargs)
157
+
158
+ @staticmethod
159
+ def _get_hub_session(model: str):
160
+ """Creates a session for Hub Training."""
161
+ from ultralytics.hub.session import HUBTrainingSession
162
+
163
+ session = HUBTrainingSession(model)
164
+ return session if session.client.authenticated else None
165
+
166
+ @staticmethod
167
+ def is_triton_model(model):
168
+ """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
169
+ from urllib.parse import urlsplit
170
+
171
+ url = urlsplit(model)
172
+ return url.netloc and url.path and url.scheme in {"http", "grpc"}
173
+
174
+ @staticmethod
175
+ def is_hub_model(model):
176
+ """Check if the provided model is a HUB model."""
177
+ return any(
178
+ (
179
+ model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
180
+ [len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
181
+ len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID
182
+ )
183
+ )
184
+
185
+ def _new(self, cfg: str, task=None, model=None, verbose=True):
186
+ """
187
+ Initializes a new model and infers the task type from the model definitions.
188
+
189
+ Args:
190
+ cfg (str): model configuration file
191
+ task (str | None): model task
192
+ model (BaseModel): Customized model.
193
+ verbose (bool): display model info on load
194
+ """
195
+ cfg_dict = yaml_model_load(cfg)
196
+ self.cfg = cfg
197
+ self.task = task or guess_model_task(cfg_dict)
198
+ self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model
199
+ self.overrides["model"] = self.cfg
200
+ self.overrides["task"] = self.task
201
+
202
+ # Below added to allow export from YAMLs
203
+ self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
204
+ self.model.task = self.task
205
+
206
+ def _load(self, weights: str, task=None):
207
+ """
208
+ Initializes a new model and infers the task type from the model head.
209
+
210
+ Args:
211
+ weights (str): model checkpoint to be loaded
212
+ task (str | None): model task
213
+ """
214
+ suffix = Path(weights).suffix
215
+ if suffix == ".pt":
216
+ self.model, self.ckpt = attempt_load_one_weight(weights)
217
+ self.task = self.model.args["task"]
218
+ self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
219
+ self.ckpt_path = self.model.pt_path
220
+ else:
221
+ weights = checks.check_file(weights)
222
+ self.model, self.ckpt = weights, None
223
+ self.task = task or guess_model_task(weights)
224
+ self.ckpt_path = weights
225
+ self.overrides["model"] = weights
226
+ self.overrides["task"] = self.task
227
+
228
+ def _check_is_pytorch_model(self):
229
+ """Raises TypeError is model is not a PyTorch model."""
230
+ pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
231
+ pt_module = isinstance(self.model, nn.Module)
232
+ if not (pt_module or pt_str):
233
+ raise TypeError(
234
+ f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
235
+ f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
236
+ f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
237
+ f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
238
+ f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
239
+ )
240
+
241
+ def reset_weights(self):
242
+ """
243
+ Resets the model parameters to randomly initialized values, effectively discarding all training information.
244
+
245
+ This method iterates through all modules in the model and resets their parameters if they have a
246
+ 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
247
+ to be updated during training.
248
+
249
+ Returns:
250
+ self (ultralytics.engine.model.Model): The instance of the class with reset weights.
251
+
252
+ Raises:
253
+ AssertionError: If the model is not a PyTorch model.
254
+ """
255
+ self._check_is_pytorch_model()
256
+ for m in self.model.modules():
257
+ if hasattr(m, "reset_parameters"):
258
+ m.reset_parameters()
259
+ for p in self.model.parameters():
260
+ p.requires_grad = True
261
+ return self
262
+
263
+ def load(self, weights="yolov8n.pt"):
264
+ """
265
+ Loads parameters from the specified weights file into the model.
266
+
267
+ This method supports loading weights from a file or directly from a weights object. It matches parameters by
268
+ name and shape and transfers them to the model.
269
+
270
+ Args:
271
+ weights (str | Path): Path to the weights file or a weights object. Defaults to 'yolov8n.pt'.
272
+
273
+ Returns:
274
+ self (ultralytics.engine.model.Model): The instance of the class with loaded weights.
275
+
276
+ Raises:
277
+ AssertionError: If the model is not a PyTorch model.
278
+ """
279
+ self._check_is_pytorch_model()
280
+ if isinstance(weights, (str, Path)):
281
+ weights, self.ckpt = attempt_load_one_weight(weights)
282
+ self.model.load(weights)
283
+ return self
284
+
285
+ def save(self, filename="model.pt"):
286
+ """
287
+ Saves the current model state to a file.
288
+
289
+ This method exports the model's checkpoint (ckpt) to the specified filename.
290
+
291
+ Args:
292
+ filename (str): The name of the file to save the model to. Defaults to 'model.pt'.
293
+
294
+ Raises:
295
+ AssertionError: If the model is not a PyTorch model.
296
+ """
297
+ self._check_is_pytorch_model()
298
+ import torch
299
+
300
+ torch.save(self.ckpt, filename)
301
+
302
+ def info(self, detailed=False, verbose=True):
303
+ """
304
+ Logs or returns model information.
305
+
306
+ This method provides an overview or detailed information about the model, depending on the arguments passed.
307
+ It can control the verbosity of the output.
308
+
309
+ Args:
310
+ detailed (bool): If True, shows detailed information about the model. Defaults to False.
311
+ verbose (bool): If True, prints the information. If False, returns the information. Defaults to True.
312
+
313
+ Returns:
314
+ (list): Various types of information about the model, depending on the 'detailed' and 'verbose' parameters.
315
+
316
+ Raises:
317
+ AssertionError: If the model is not a PyTorch model.
318
+ """
319
+ self._check_is_pytorch_model()
320
+ return self.model.info(detailed=detailed, verbose=verbose)
321
+
322
+ def fuse(self):
323
+ """
324
+ Fuses Conv2d and BatchNorm2d layers in the model.
325
+
326
+ This method optimizes the model by fusing Conv2d and BatchNorm2d layers, which can improve inference speed.
327
+
328
+ Raises:
329
+ AssertionError: If the model is not a PyTorch model.
330
+ """
331
+ self._check_is_pytorch_model()
332
+ self.model.fuse()
333
+
334
+ def embed(self, source=None, stream=False, **kwargs):
335
+ """
336
+ Generates image embeddings based on the provided source.
337
+
338
+ This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image source.
339
+ It allows customization of the embedding process through various keyword arguments.
340
+
341
+ Args:
342
+ source (str | int | PIL.Image | np.ndarray): The source of the image for generating embeddings.
343
+ The source can be a file path, URL, PIL image, numpy array, etc. Defaults to None.
344
+ stream (bool): If True, predictions are streamed. Defaults to False.
345
+ **kwargs (dict): Additional keyword arguments for configuring the embedding process.
346
+
347
+ Returns:
348
+ (List[torch.Tensor]): A list containing the image embeddings.
349
+
350
+ Raises:
351
+ AssertionError: If the model is not a PyTorch model.
352
+ """
353
+ if not kwargs.get("embed"):
354
+ kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
355
+ return self.predict(source, stream, **kwargs)
356
+
357
+ def predict(self, source=None, stream=False, predictor=None, **kwargs):
358
+ """
359
+ Performs predictions on the given image source using the YOLO model.
360
+
361
+ This method facilitates the prediction process, allowing various configurations through keyword arguments.
362
+ It supports predictions with custom predictors or the default predictor method. The method handles different
363
+ types of image sources and can operate in a streaming mode. It also provides support for SAM-type models
364
+ through 'prompts'.
365
+
366
+ The method sets up a new predictor if not already present and updates its arguments with each call.
367
+ It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it
368
+ is being called from the command line interface and adjusts its behavior accordingly, including setting defaults
369
+ for confidence threshold and saving behavior.
370
+
371
+ Args:
372
+ source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
373
+ Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to ASSETS.
374
+ stream (bool, optional): Treats the input source as a continuous stream for predictions. Defaults to False.
375
+ predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
376
+ If None, the method uses a default predictor. Defaults to None.
377
+ **kwargs (dict): Additional keyword arguments for configuring the prediction process. These arguments allow
378
+ for further customization of the prediction behavior.
379
+
380
+ Returns:
381
+ (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.
382
+
383
+ Raises:
384
+ AttributeError: If the predictor is not properly set up.
385
+ """
386
+ if source is None:
387
+ source = ASSETS
388
+ LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
389
+
390
+ is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
391
+ x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
392
+ )
393
+
394
+ custom = {"conf": 0.25, "save": is_cli, "mode": "predict"} # method defaults
395
+ args = {**self.overrides, **custom, **kwargs} # highest priority args on the right
396
+ prompts = args.pop("prompts", None) # for SAM-type models
397
+
398
+ if not self.predictor:
399
+ self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
400
+ self.predictor.setup_model(model=self.model, verbose=is_cli)
401
+ else: # only update args if predictor is already setup
402
+ self.predictor.args = get_cfg(self.predictor.args, args)
403
+ if "project" in args or "name" in args:
404
+ self.predictor.save_dir = get_save_dir(self.predictor.args)
405
+ if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
406
+ self.predictor.set_prompts(prompts)
407
+ return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
408
+
409
+ def track(self, source=None, stream=False, persist=False, **kwargs):
410
+ """
411
+ Conducts object tracking on the specified input source using the registered trackers.
412
+
413
+ This method performs object tracking using the model's predictors and optionally registered trackers. It is
414
+ capable of handling different types of input sources such as file paths or video streams. The method supports
415
+ customization of the tracking process through various keyword arguments. It registers trackers if they are not
416
+ already present and optionally persists them based on the 'persist' flag.
417
+
418
+ The method sets a default confidence threshold specifically for ByteTrack-based tracking, which requires low
419
+ confidence predictions as input. The tracking mode is explicitly set in the keyword arguments.
420
+
421
+ Args:
422
+ source (str, optional): The input source for object tracking. It can be a file path, URL, or video stream.
423
+ stream (bool, optional): Treats the input source as a continuous video stream. Defaults to False.
424
+ persist (bool, optional): Persists the trackers between different calls to this method. Defaults to False.
425
+ **kwargs (dict): Additional keyword arguments for configuring the tracking process. These arguments allow
426
+ for further customization of the tracking behavior.
427
+
428
+ Returns:
429
+ (List[ultralytics.engine.results.Results]): A list of tracking results, encapsulated in the Results class.
430
+
431
+ Raises:
432
+ AttributeError: If the predictor does not have registered trackers.
433
+ """
434
+ if not hasattr(self.predictor, "trackers"):
435
+ from ultralytics.trackers import register_tracker
436
+
437
+ register_tracker(self, persist)
438
+ kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input
439
+ kwargs["mode"] = "track"
440
+ return self.predict(source=source, stream=stream, **kwargs)
441
+
442
+ def val(self, validator=None, **kwargs):
443
+ """
444
+ Validates the model using a specified dataset and validation configuration.
445
+
446
+ This method facilitates the model validation process, allowing for a range of customization through various
447
+ settings and configurations. It supports validation with a custom validator or the default validation approach.
448
+ The method combines default configurations, method-specific defaults, and user-provided arguments to configure
449
+ the validation process. After validation, it updates the model's metrics with the results obtained from the
450
+ validator.
451
+
452
+ The method supports various arguments that allow customization of the validation process. For a comprehensive
453
+ list of all configurable options, users should refer to the 'configuration' section in the documentation.
454
+
455
+ Args:
456
+ validator (BaseValidator, optional): An instance of a custom validator class for validating the model. If
457
+ None, the method uses a default validator. Defaults to None.
458
+ **kwargs (dict): Arbitrary keyword arguments representing the validation configuration. These arguments are
459
+ used to customize various aspects of the validation process.
460
+
461
+ Returns:
462
+ (dict): Validation metrics obtained from the validation process.
463
+
464
+ Raises:
465
+ AssertionError: If the model is not a PyTorch model.
466
+ """
467
+ custom = {"rect": True} # method defaults
468
+ args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
469
+
470
+ validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
471
+ validator(model=self.model)
472
+ self.metrics = validator.metrics
473
+ return validator.metrics
474
+
475
+ def benchmark(self, **kwargs):
476
+ """
477
+ Benchmarks the model across various export formats to evaluate performance.
478
+
479
+ This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
480
+ It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured
481
+ using a combination of default configuration values, model-specific arguments, method-specific defaults, and
482
+ any additional user-provided keyword arguments.
483
+
484
+ The method supports various arguments that allow customization of the benchmarking process, such as dataset
485
+ choice, image size, precision modes, device selection, and verbosity. For a comprehensive list of all
486
+ configurable options, users should refer to the 'configuration' section in the documentation.
487
+
488
+ Args:
489
+ **kwargs (dict): Arbitrary keyword arguments to customize the benchmarking process. These are combined with
490
+ default configurations, model-specific arguments, and method defaults.
491
+
492
+ Returns:
493
+ (dict): A dictionary containing the results of the benchmarking process.
494
+
495
+ Raises:
496
+ AssertionError: If the model is not a PyTorch model.
497
+ """
498
+ self._check_is_pytorch_model()
499
+ from ultralytics.utils.benchmarks import benchmark
500
+
501
+ custom = {"verbose": False} # method defaults
502
+ args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
503
+ return benchmark(
504
+ model=self,
505
+ data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets
506
+ imgsz=args["imgsz"],
507
+ half=args["half"],
508
+ int8=args["int8"],
509
+ device=args["device"],
510
+ verbose=kwargs.get("verbose"),
511
+ )
512
+
513
+ def export(self, **kwargs):
514
+ """
515
+ Exports the model to a different format suitable for deployment.
516
+
517
+ This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
518
+ purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
519
+ defaults, and any additional arguments provided. The combined arguments are used to configure export settings.
520
+
521
+ The method supports a wide range of arguments to customize the export process. For a comprehensive list of all
522
+ possible arguments, refer to the 'configuration' section in the documentation.
523
+
524
+ Args:
525
+ **kwargs (dict): Arbitrary keyword arguments to customize the export process. These are combined with the
526
+ model's overrides and method defaults.
527
+
528
+ Returns:
529
+ (object): The exported model in the specified format, or an object related to the export process.
530
+
531
+ Raises:
532
+ AssertionError: If the model is not a PyTorch model.
533
+ """
534
+ self._check_is_pytorch_model()
535
+ from .exporter import Exporter
536
+
537
+ custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False} # method defaults
538
+ args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
539
+ return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
540
+
541
+ def train(self, trainer=None, **kwargs):
542
+ """
543
+ Trains the model using the specified dataset and training configuration.
544
+
545
+ This method facilitates model training with a range of customizable settings and configurations. It supports
546
+ training with a custom trainer or the default training approach defined in the method. The method handles
547
+ different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
548
+ updating model and configuration after training.
549
+
550
+ When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
551
+ arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
552
+ configurations, method-specific defaults, and user-provided arguments to configure the training process. After
553
+ training, it updates the model and its configurations, and optionally attaches metrics.
554
+
555
+ Args:
556
+ trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the
557
+ method uses a default trainer. Defaults to None.
558
+ **kwargs (dict): Arbitrary keyword arguments representing the training configuration. These arguments are
559
+ used to customize various aspects of the training process.
560
+
561
+ Returns:
562
+ (dict | None): Training metrics if available and training is successful; otherwise, None.
563
+
564
+ Raises:
565
+ AssertionError: If the model is not a PyTorch model.
566
+ PermissionError: If there is a permission issue with the HUB session.
567
+ ModuleNotFoundError: If the HUB SDK is not installed.
568
+ """
569
+ self._check_is_pytorch_model()
570
+ if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
571
+ if any(kwargs):
572
+ LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
573
+ kwargs = self.session.train_args # overwrite kwargs
574
+
575
+ checks.check_pip_update_available()
576
+
577
+ overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
578
+ custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults
579
+ args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
580
+ # if args.get("resume"):
581
+ # args["resume"] = self.ckpt_path
582
+
583
+ self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
584
+ if not args.get("resume"): # manually set model only if not resuming
585
+ self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
586
+ self.model = self.trainer.model
587
+
588
+ if SETTINGS["hub"] is True and not self.session:
589
+ # Create a model in HUB
590
+ try:
591
+ self.session = self._get_hub_session(self.model_name)
592
+ if self.session:
593
+ self.session.create_model(args)
594
+ # Check model was created
595
+ if not getattr(self.session.model, "id", None):
596
+ self.session = None
597
+ except (PermissionError, ModuleNotFoundError):
598
+ # Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
599
+ pass
600
+
601
+ self.trainer.hub_session = self.session # attach optional HUB session
602
+ self.trainer.train()
603
+ # Update model and cfg after training
604
+ if RANK in (-1, 0):
605
+ ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
606
+ self.model, _ = attempt_load_one_weight(ckpt)
607
+ self.overrides = self.model.args
608
+ self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
609
+ return self.metrics
610
+
611
+ def tune(self, use_ray=False, iterations=10, *args, **kwargs):
612
+ """
613
+ Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
614
+
615
+ This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.
616
+ When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module.
617
+ Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and
618
+ custom arguments to configure the tuning process.
619
+
620
+ Args:
621
+ use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
622
+ iterations (int): The number of tuning iterations to perform. Defaults to 10.
623
+ *args (list): Variable length argument list for additional arguments.
624
+ **kwargs (dict): Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
625
+
626
+ Returns:
627
+ (dict): A dictionary containing the results of the hyperparameter search.
628
+
629
+ Raises:
630
+ AssertionError: If the model is not a PyTorch model.
631
+ """
632
+ self._check_is_pytorch_model()
633
+ if use_ray:
634
+ from ultralytics.utils.tuner import run_ray_tune
635
+
636
+ return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
637
+ else:
638
+ from .tuner import Tuner
639
+
640
+ custom = {} # method defaults
641
+ args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
642
+ return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
643
+
644
    def _apply(self, fn):
        """
        Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers.

        Args:
            fn (callable): Transform applied by nn.Module._apply to each tensor.

        Returns:
            self (ultralytics.engine.model.Model): This instance, with the transform applied.
        """
        self._check_is_pytorch_model()
        self = super()._apply(fn)  # noqa
        self.predictor = None  # reset predictor as device may have changed
        self.overrides["device"] = self.device  # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
        return self
651
+
652
+ @property
653
+ def names(self):
654
+ """
655
+ Retrieves the class names associated with the loaded model.
656
+
657
+ This property returns the class names if they are defined in the model. It checks the class names for validity
658
+ using the 'check_class_names' function from the ultralytics.nn.autobackend module.
659
+
660
+ Returns:
661
+ (list | None): The class names of the model if available, otherwise None.
662
+ """
663
+ from ultralytics.nn.autobackend import check_class_names
664
+
665
+ return check_class_names(self.model.names) if hasattr(self.model, "names") else None
666
+
667
+ @property
668
+ def device(self):
669
+ """
670
+ Retrieves the device on which the model's parameters are allocated.
671
+
672
+ This property is used to determine whether the model's parameters are on CPU or GPU. It only applies to models
673
+ that are instances of nn.Module.
674
+
675
+ Returns:
676
+ (torch.device | None): The device (CPU/GPU) of the model if it is a PyTorch model, otherwise None.
677
+ """
678
+ return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
679
+
680
+ @property
681
+ def transforms(self):
682
+ """
683
+ Retrieves the transformations applied to the input data of the loaded model.
684
+
685
+ This property returns the transformations if they are defined in the model.
686
+
687
+ Returns:
688
+ (object | None): The transform object of the model if available, otherwise None.
689
+ """
690
+ return self.model.transforms if hasattr(self.model, "transforms") else None
691
+
692
+ def add_callback(self, event: str, func):
693
+ """
694
+ Adds a callback function for a specified event.
695
+
696
+ This method allows the user to register a custom callback function that is triggered on a specific event during
697
+ model training or inference.
698
+
699
+ Args:
700
+ event (str): The name of the event to attach the callback to.
701
+ func (callable): The callback function to be registered.
702
+
703
+ Raises:
704
+ ValueError: If the event name is not recognized.
705
+ """
706
+ self.callbacks[event].append(func)
707
+
708
+ def clear_callback(self, event: str):
709
+ """
710
+ Clears all callback functions registered for a specified event.
711
+
712
+ This method removes all custom and default callback functions associated with the given event.
713
+
714
+ Args:
715
+ event (str): The name of the event for which to clear the callbacks.
716
+
717
+ Raises:
718
+ ValueError: If the event name is not recognized.
719
+ """
720
+ self.callbacks[event] = []
721
+
722
+ def reset_callbacks(self):
723
+ """
724
+ Resets all callbacks to their default functions.
725
+
726
+ This method reinstates the default callback functions for all events, removing any custom callbacks that were
727
+ added previously.
728
+ """
729
+ for event in callbacks.default_callbacks.keys():
730
+ self.callbacks[event] = [callbacks.default_callbacks[event][0]]
731
+
732
+ @staticmethod
733
+ def _reset_ckpt_args(args):
734
+ """Reset arguments when loading a PyTorch model."""
735
+ include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
736
+ return {k: v for k, v in args.items() if k in include}
737
+
738
+ # def __getattr__(self, attr):
739
+ # """Raises error if object has no requested attribute."""
740
+ # name = self.__class__.__name__
741
+ # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
742
+
743
    def _smart_load(self, key):
        """
        Load model/trainer/validator/predictor.

        Looks up the class registered in 'task_map' for the current task; raises NotImplementedError naming the
        calling mode if the task does not support it.
        """
        try:
            return self.task_map[self.task][key]
        except Exception as e:
            name = self.__class__.__name__
            mode = inspect.stack()[1][3]  # get the function name.
            raise NotImplementedError(
                emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")
            ) from e
753
+
754
+ @property
755
+ def task_map(self):
756
+ """
757
+ Map head to model, trainer, validator, and predictor classes.
758
+
759
+ Returns:
760
+ task_map (dict): The map of model task to mode classes.
761
+ """
762
+ raise NotImplementedError("Please provide task map for your model!")
763
+
764
+ def profile(self, imgsz):
765
+ if type(imgsz) is int:
766
+ inputs = torch.randn((2, 3, imgsz, imgsz))
767
+ else:
768
+ inputs = torch.randn((2, 3, imgsz[0], imgsz[1]))
769
+ if next(self.model.parameters()).device.type == 'cuda':
770
+ return self.model.predict(inputs.to(torch.device('cuda')), profile=True)
771
+ else:
772
+ self.model.predict(inputs, profile=True)
yolov8_model/ultralytics/engine/predictor.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
4
+
5
+ Usage - sources:
6
+ $ yolo mode=predict model=yolov8n.pt source=0 # webcam
7
+ img.jpg # image
8
+ vid.mp4 # video
9
+ screen # screenshot
10
+ path/ # directory
11
+ list.txt # list of images
12
+ list.streams # list of streams
13
+ 'path/*.jpg' # glob
14
+ 'https://youtu.be/LNwODJXcvt4' # YouTube
15
+ 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream
16
+
17
+ Usage - formats:
18
+ $ yolo mode=predict model=yolov8n.pt # PyTorch
19
+ yolov8n.torchscript # TorchScript
20
+ yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
21
+ yolov8n_openvino_model # OpenVINO
22
+ yolov8n.engine # TensorRT
23
+ yolov8n.mlpackage # CoreML (macOS-only)
24
+ yolov8n_saved_model # TensorFlow SavedModel
25
+ yolov8n.pb # TensorFlow GraphDef
26
+ yolov8n.tflite # TensorFlow Lite
27
+ yolov8n_edgetpu.tflite # TensorFlow Edge TPU
28
+ yolov8n_paddle_model # PaddlePaddle
29
+ """
30
+ import platform
31
+ import threading
32
+ from pathlib import Path
33
+
34
+ import cv2
35
+ import numpy as np
36
+ import torch
37
+ from PIL import Image
38
+ from yolov8_model.ultralytics.cfg import get_cfg, get_save_dir
39
+ from yolov8_model.ultralytics.data import load_inference_source
40
+ from yolov8_model.ultralytics.data.augment import LetterBox, classify_transforms
41
+ from yolov8_model.ultralytics.nn.autobackend import AutoBackend
42
+ from yolov8_model.ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
43
+ from yolov8_model.ultralytics.utils.checks import check_imgsz, check_imshow
44
+ from yolov8_model.ultralytics.utils.files import increment_path
45
+ from yolov8_model.ultralytics.utils.torch_utils import select_device, smart_inference_mode
46
+
47
+ STREAM_WARNING = """
48
+ WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
49
+ errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
50
+
51
+ Example:
52
+ results = model(source=..., stream=True) # generator of Results objects
53
+ for r in results:
54
+ boxes = r.boxes # Boxes object for bbox outputs
55
+ masks = r.masks # Masks object for segment masks outputs
56
+ probs = r.probs # Class probabilities for classification outputs
57
+ """
58
+
59
+
60
class BasePredictor:
    """
    A base class for creating predictors.

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether the predictor has finished setup.
        model (nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_path (list): Per-source video paths (one slot per batch source); see setup_source().
        vid_writer (list): Per-source cv2.VideoWriter handles for saving video output.
        data_path (str): Path to data.
    """
78
+
79
    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BasePredictor class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (dict, optional): Callback registry; defaults to the framework's default callbacks.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = get_save_dir(self.args)
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)  # disable show if the environment has no display

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_path, self.vid_writer, self.vid_frame = None, None, None
        self.plotted_img = None
        self.data_path = None
        self.source_type = None
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        self._lock = threading.Lock()  # for automatic thread-safe inference
        callbacks.add_integration_callbacks(self)
112
+
113
+ def preprocess(self, im):
114
+ """
115
+ Prepares input image before inference.
116
+
117
+ Args:
118
+ im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
119
+ """
120
+ not_tensor = not isinstance(im, torch.Tensor)
121
+ if not_tensor:
122
+ im = np.stack(self.pre_transform(im))
123
+ im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
124
+ im = np.ascontiguousarray(im) # contiguous
125
+ im = torch.from_numpy(im)
126
+
127
+ im = im.to(self.device)
128
+ im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
129
+ if not_tensor:
130
+ im /= 255 # 0 - 255 to 0.0 - 1.0
131
+ return im
132
+
133
+ def inference(self, im, *args, **kwargs):
134
+ """Runs inference on a given image using the specified model and arguments."""
135
+ visualize = (
136
+ increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
137
+ if self.args.visualize and (not self.source_type.tensor)
138
+ else False
139
+ )
140
+ return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
141
+
142
+ def pre_transform(self, im):
143
+ """
144
+ Pre-transform input image before inference.
145
+
146
+ Args:
147
+ im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
148
+
149
+ Returns:
150
+ (list): A list of transformed images.
151
+ """
152
+ same_shapes = all(x.shape == im[0].shape for x in im)
153
+ letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
154
+ return [letterbox(image=x) for x in im]
155
+
156
    def write_results(self, idx, results, batch):
        """
        Write inference results to a file or directory.

        Args:
            idx (int): Index of the current image within the batch.
            results (list): Results objects for the whole batch.
            batch (tuple): (path, image, ...) tuple for the current batch.

        Returns:
            (str): Log string summarizing this image's results.
        """
        p, im, _ = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            log_string += f"{idx}: "
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, "frame", 0)
        self.data_path = p
        # Video/stream sources get a frame suffix so per-frame label files don't overwrite each other
        self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}")
        log_string += "%gx%g " % im.shape[2:]  # print string
        result = results[idx]
        log_string += result.verbose()

        if self.args.save or self.args.show:  # Add bbox to image
            plot_args = {
                "line_width": self.args.line_width,
                "boxes": self.args.show_boxes,
                "conf": self.args.show_conf,
                "labels": self.args.show_labels,
            }
            if not self.args.retina_masks:
                plot_args["im_gpu"] = im[idx]
            self.plotted_img = result.plot(**plot_args)
        # Write
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(
                save_dir=self.save_dir / "crops",
                file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"),
            )

        return log_string
193
+
194
+ def postprocess(self, preds, img, orig_imgs):
195
+ """Post-processes predictions for an image and returns them."""
196
+ return preds
197
+
198
+ def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
199
+ """Performs inference on an image or stream."""
200
+ self.stream = stream
201
+ if stream:
202
+ return self.stream_inference(source, model, *args, **kwargs)
203
+ else:
204
+ return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
205
+
206
+ def predict_cli(self, source=None, model=None):
207
+ """
208
+ Method used for CLI prediction.
209
+
210
+ It uses always generator as outputs as not required by CLI mode.
211
+ """
212
+ gen = self.stream_inference(source, model)
213
+ for _ in gen: # noqa, running CLI inference without accumulating any outputs (do not modify)
214
+ pass
215
+
216
    def setup_source(self, source):
        """
        Set up the inference source and mode.

        Args:
            source: Any source accepted by load_inference_source (path, URL, tensor, etc.).
        """
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        # Classification tasks may carry their own transforms on the model; fall back to defaults otherwise
        self.transforms = (
            getattr(
                self.model.model,
                "transforms",
                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
            )
            if self.args.task == "classify"
            else None
        )
        self.dataset = load_inference_source(
            source=source, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
        )
        self.source_type = self.dataset.source_type
        # Warn when a non-streaming call would accumulate a very large number of Results in RAM
        if not getattr(self, "stream", True) and (
            self.dataset.mode == "stream"  # streams
            or len(self.dataset) > 1000  # images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)
        # Per-source video-writer state, one slot per batch element
        self.vid_path = [None] * self.dataset.bs
        self.vid_writer = [None] * self.dataset.bs
        self.vid_frame = [None] * self.dataset.bs
241
+
242
    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """
        Stream real-time inference on a source and save results to file.

        Yields:
            Result objects one at a time (or embedding tensors when ``self.args.embed`` is set).

        Note:
            This is a generator; the trailing ``return all_results`` only surfaces as
            ``StopIteration.value`` and is not seen by normal ``for`` iteration.
        """
        if self.args.verbose:
            LOGGER.info("")

        # Setup model
        if not self.model:
            self.setup_model(model)

        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)

            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            # One profiler each for preprocess / inference / postprocess timing
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            all_results = []

            for batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                self.batch = batch
                path, im0s, vid_cap, s = batch

                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        # Embedding mode short-circuits: yield raw tensors, skip postprocess/saving
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)

                self.run_callbacks("on_predict_postprocess_end")
                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    # Per-image speeds: batch time split evenly across the n images
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
                    p = Path(p)

                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s += self.write_results(i, self.results, (p, im, im0))
                    if self.args.save or self.args.save_txt:
                        self.results[i].save_dir = self.save_dir.__str__()
                    if self.args.show and self.plotted_img is not None:
                        self.show(p)
                    if self.args.save and self.plotted_img is not None:
                        self.save_preds(vid_cap, i, str(self.save_dir / p.name))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results
                all_results.extend(self.results)
                # Print time (inference-only)
                if self.args.verbose:
                    LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms")

        # Release assets
        if isinstance(self.vid_writer[-1], cv2.VideoWriter):
            self.vid_writer[-1].release()  # release final video writer

        # Print results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(1, 3, *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

        self.run_callbacks("on_predict_end")
        return all_results
341
    def setup_model(self, model, verbose=True):
        """
        Initialize the model with given parameters and set it to evaluation mode.

        Args:
            model: Model weights path/object; falls back to ``self.args.model`` when falsy.
            verbose (bool): Whether device selection and backend loading log details.
        """
        self.model = AutoBackend(
            model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            fuse=True,
            verbose=verbose,
        )

        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()
356
+
357
    def show(self, p):
        """
        Display the most recently plotted image in a window using OpenCV imshow().

        Args:
            p: Source path used as the window title (one window per distinct source).
        """
        im0 = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
        cv2.imshow(str(p), im0)
        # Hold 500 ms for still images so they are visible; 1 ms for video/stream frames
        cv2.waitKey(500 if self.batch[3].startswith("image") else 1)  # 1 millisecond
366
+
367
    def save_preds(self, vid_cap, idx, save_path):
        """
        Save the plotted prediction image, appending to a video file for video/stream sources.

        Args:
            vid_cap: cv2.VideoCapture for video sources, falsy for streams.
            idx (int): Batch-slot index selecting the per-source writer state.
            save_path (str): Destination path for the image or video.
        """
        im0 = self.plotted_img
        # Save imgs
        if self.dataset.mode == "image":
            cv2.imwrite(save_path, im0)
        else:  # 'video' or 'stream'
            # Frames directory name is derived from save_path up to its first "."
            frames_path = f'{save_path.split(".", 1)[0]}_frames/'
            if self.vid_path[idx] != save_path:  # new video
                self.vid_path[idx] = save_path
                if self.args.save_frames:
                    Path(frames_path).mkdir(parents=True, exist_ok=True)
                    self.vid_frame[idx] = 0
                if isinstance(self.vid_writer[idx], cv2.VideoWriter):
                    self.vid_writer[idx].release()  # release previous video writer
                if vid_cap:  # video
                    fps = int(vid_cap.get(cv2.CAP_PROP_FPS))  # integer required, floats produce error in MP4 codec
                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                else:  # stream
                    fps, w, h = 30, im0.shape[1], im0.shape[0]
                # Codec/container choice is platform-dependent
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                self.vid_writer[idx] = cv2.VideoWriter(
                    str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)
                )
            # Write video
            self.vid_writer[idx].write(im0)

            # Write frame
            if self.args.save_frames:
                cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0)
                self.vid_frame[idx] += 1
399
+
400
+ def run_callbacks(self, event: str):
401
+ """Runs all registered callbacks for a specific event."""
402
+ for callback in self.callbacks.get(event, []):
403
+ callback(self)
404
+
405
+ def add_callback(self, event: str, func):
406
+ """Add callback."""
407
+ self.callbacks[event].append(func)
yolov8_model/ultralytics/engine/results.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Ultralytics Results, Boxes and Masks classes for handling inference results.
4
+
5
+ Usage: See https://docs.ultralytics.com/modes/predict/
6
+ """
7
+
8
+ from copy import deepcopy
9
+ from functools import lru_cache
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ from yolov8_model.ultralytics.data.augment import LetterBox
16
+ from yolov8_model.ultralytics.utils import LOGGER, SimpleClass, ops
17
+ from yolov8_model.ultralytics.utils.plotting import Annotator, colors, save_one_box
18
+ from yolov8_model.ultralytics.utils.torch_utils import smart_inference_mode
19
+
20
+
21
class BaseTensor(SimpleClass):
    """Base tensor class with additional methods for easy manipulation and device handling."""

    def __init__(self, data, orig_shape) -> None:
        """
        Initialize BaseTensor with prediction data and the original image shape.

        Args:
            data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
            orig_shape (tuple): Original shape of image.
        """
        assert isinstance(data, (torch.Tensor, np.ndarray))
        self.data = data
        self.orig_shape = orig_shape

    @property
    def shape(self):
        """Shape of the underlying data tensor."""
        return self.data.shape

    def cpu(self):
        """Return a copy on CPU memory; numpy-backed instances are returned unchanged."""
        if isinstance(self.data, np.ndarray):
            return self
        return self.__class__(self.data.cpu(), self.orig_shape)

    def numpy(self):
        """Return a copy backed by a numpy array; numpy-backed instances are returned unchanged."""
        if isinstance(self.data, np.ndarray):
            return self
        return self.__class__(self.data.numpy(), self.orig_shape)

    def cuda(self):
        """Return a copy on GPU memory (numpy data is first converted to a tensor)."""
        return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)

    def to(self, *args, **kwargs):
        """Return a copy moved to the specified device and/or dtype."""
        return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)

    def __len__(self):  # override len(results)
        """Number of entries in the data tensor."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return a new instance of this class holding the indexed slice of the data."""
        return self.__class__(self.data[idx], self.orig_shape)
65
+
66
class Results(SimpleClass):
    """
    A class for storing and manipulating inference results.

    Args:
        orig_img (numpy.ndarray): The original image as a numpy array.
        path (str): The path to the image file.
        names (dict): A dictionary of class names.
        boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
        masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
        probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
        keypoints (List[List[float]], optional): A list of detected keypoints for each object.

    Attributes:
        orig_img (numpy.ndarray): The original image as a numpy array.
        orig_shape (tuple): The original image shape in (height, width) format.
        boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
        masks (Masks, optional): A Masks object containing the detection masks.
        probs (Probs, optional): A Probs object containing probabilities of each class for classification task.
        keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object.
        speed (dict): A dictionary of preprocess, inference, and postprocess speeds in milliseconds per image.
        names (dict): A dictionary of class names.
        path (str): The path to the image file.
        _keys (tuple): A tuple of attribute names for non-empty attributes.
    """

    def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None) -> None:
        """Initialize the Results class."""
        self.orig_img = orig_img
        self.orig_shape = orig_img.shape[:2]
        self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None  # native size boxes
        self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks
        self.probs = Probs(probs) if probs is not None else None
        self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
        self.obb = OBB(obb, self.orig_shape) if obb is not None else None
        self.speed = {"preprocess": None, "inference": None, "postprocess": None}  # milliseconds per image
        self.names = names
        self.path = path
        self.save_dir = None
        self._keys = "boxes", "masks", "probs", "keypoints", "obb"

    def __getitem__(self, idx):
        """Return a Results object for the specified index."""
        return self._apply("__getitem__", idx)

    def __len__(self):
        """Return the number of detections in the Results object."""
        # NOTE(review): implicitly returns None when every attribute is None — callers
        # appear to rely on at least one attribute being set; confirm before changing.
        for k in self._keys:
            v = getattr(self, k)
            if v is not None:
                return len(v)

    def update(self, boxes=None, masks=None, probs=None, obb=None):
        """Update the boxes, masks, and probs attributes of the Results object."""
        if boxes is not None:
            self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
        if masks is not None:
            self.masks = Masks(masks, self.orig_shape)
        if probs is not None:
            # NOTE(review): probs is stored raw here, unlike __init__ which wraps it in Probs.
            self.probs = probs
        if obb is not None:
            self.obb = OBB(obb, self.orig_shape)

    def _apply(self, fn, *args, **kwargs):
        """
        Applies a function to all non-empty attributes and returns a new Results object with modified attributes. This
        function is internally called by methods like .to(), .cuda(), .cpu(), etc.

        Args:
            fn (str): The name of the function to apply.
            *args: Variable length argument list to pass to the function.
            **kwargs: Arbitrary keyword arguments to pass to the function.

        Returns:
            Results: A new Results object with attributes modified by the applied function.
        """
        r = self.new()
        for k in self._keys:
            v = getattr(self, k)
            if v is not None:
                setattr(r, k, getattr(v, fn)(*args, **kwargs))
        return r

    def cpu(self):
        """Return a copy of the Results object with all tensors on CPU memory."""
        return self._apply("cpu")

    def numpy(self):
        """Return a copy of the Results object with all tensors as numpy arrays."""
        return self._apply("numpy")

    def cuda(self):
        """Return a copy of the Results object with all tensors on GPU memory."""
        return self._apply("cuda")

    def to(self, *args, **kwargs):
        """Return a copy of the Results object with tensors on the specified device and dtype."""
        return self._apply("to", *args, **kwargs)

    def new(self):
        """Return a new Results object with the same image, path, and names (no predictions)."""
        return Results(orig_img=self.orig_img, path=self.path, names=self.names)

    def plot(
        self,
        conf=True,
        line_width=None,
        font_size=None,
        font="Arial.ttf",
        pil=False,
        img=None,
        im_gpu=None,
        kpt_radius=5,
        kpt_line=True,
        labels=True,
        boxes=True,
        masks=True,
        probs=True,
    ):
        """
        Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.

        Args:
            conf (bool): Whether to plot the detection confidence score.
            line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
            font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
            font (str): The font to use for the text.
            pil (bool): Whether to return the image as a PIL Image.
            img (numpy.ndarray): Plot to another image. if not, plot to original image.
            im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
            kpt_radius (int, optional): Radius of the drawn keypoints. Default is 5.
            kpt_line (bool): Whether to draw lines connecting keypoints.
            labels (bool): Whether to plot the label of bounding boxes.
            boxes (bool): Whether to plot the bounding boxes.
            masks (bool): Whether to plot the masks.
            probs (bool): Whether to plot classification probability

        Returns:
            (numpy.ndarray): A numpy array of the annotated image.

        Example:
            ```python
            from PIL import Image
            from ultralytics import YOLO

            model = YOLO('yolov8n.pt')
            results = model('bus.jpg')  # results list
            for r in results:
                im_array = r.plot()  # plot a BGR numpy array of predictions
                im = Image.fromarray(im_array[..., ::-1])  # RGB PIL image
                im.show()  # show image
                im.save('results.jpg')  # save image
            ```
        """
        if img is None and isinstance(self.orig_img, torch.Tensor):
            # Tensor sources: convert the first CHW image to a HWC uint8 numpy array for drawing
            img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()

        names = self.names
        is_obb = self.obb is not None
        pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
        pred_masks, show_masks = self.masks, masks
        pred_probs, show_probs = self.probs, probs
        annotator = Annotator(
            deepcopy(self.orig_img if img is None else img),
            line_width,
            font_size,
            font,
            pil or (pred_probs is not None and show_probs),  # Classify tasks default to pil=True
            example=names,
        )

        # Plot Segment results
        if pred_masks and show_masks:
            if im_gpu is None:
                # Build a normalized CHW BGR->RGB GPU image matching the mask resolution
                img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
                im_gpu = (
                    torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device)
                    .permute(2, 0, 1)
                    .flip(0)
                    .contiguous()
                    / 255
                )
            idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
            annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)

        # Plot Detect results
        if pred_boxes is not None and show_boxes:
            for d in reversed(pred_boxes):
                # NOTE: `conf` is rebound from the bool flag to the per-box float inside this loop
                c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
                name = ("" if id is None else f"id:{id} ") + names[c]
                label = (f"{name} {conf:.2f}" if conf else name) if labels else None
                box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze()
                annotator.box_label(box, label, color=colors(c, True), rotated=is_obb)

        # Plot Classify results
        if pred_probs is not None and show_probs:
            text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5)
            x = round(self.orig_shape[0] * 0.03)
            annotator.text([x, x], text, txt_color=(255, 255, 255))  # TODO: allow setting colors

        # Plot Pose results
        if self.keypoints is not None:
            for k in reversed(self.keypoints.data):
                annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)

        return annotator.result()

    def verbose(self):
        """Return log string for each task."""
        log_string = ""
        probs = self.probs
        boxes = self.boxes
        if len(self) == 0:
            return log_string if probs is not None else f"{log_string}(no detections), "
        if probs is not None:
            log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
        if boxes:
            for c in boxes.cls.unique():
                n = (boxes.cls == c).sum()  # detections per class
                log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "
        return log_string

    def save_txt(self, txt_file, save_conf=False):
        """
        Save predictions into txt file.

        Args:
            txt_file (str): txt file path.
            save_conf (bool): save confidence score or not.
        """
        is_obb = self.obb is not None
        boxes = self.obb if is_obb else self.boxes
        masks = self.masks
        probs = self.probs
        kpts = self.keypoints
        texts = []
        if probs is not None:
            # Classify (comprehension used for its append side effect)
            [texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5]
        elif boxes:
            # Detect/segment/pose
            for j, d in enumerate(boxes):
                c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
                line = (c, *(d.xyxyxyxyn.view(-1) if is_obb else d.xywhn.view(-1)))
                if masks:
                    # Segment lines replace the box coords with the polygon coords
                    seg = masks[j].xyn[0].copy().reshape(-1)  # reversed mask.xyn, (n,2) to (n*2)
                    line = (c, *seg)
                if kpts is not None:
                    kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
                    line += (*kpt.reshape(-1).tolist(),)
                line += (conf,) * save_conf + (() if id is None else (id,))
                texts.append(("%g " * len(line)).rstrip() % line)

        if texts:
            Path(txt_file).parent.mkdir(parents=True, exist_ok=True)  # make directory
            with open(txt_file, "a") as f:
                f.writelines(text + "\n" for text in texts)

    def save_crop(self, save_dir, file_name=Path("im.jpg")):
        """
        Save cropped predictions to `save_dir/cls/file_name.jpg`.

        Args:
            save_dir (str | pathlib.Path): Save path.
            file_name (str | pathlib.Path): File name.
        """
        if self.probs is not None:
            LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.")
            return
        if self.obb is not None:
            LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.")
            return
        for d in self.boxes:
            save_one_box(
                d.xyxy,
                self.orig_img.copy(),
                file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg",
                BGR=True,
            )

    def tojson(self, normalize=False):
        """Convert the object to JSON format."""
        if self.probs is not None:
            LOGGER.warning("Warning: Classify task do not support `tojson` yet.")
            return

        import json

        # Create list of detection dictionaries
        results = []
        data = self.boxes.data.cpu().tolist()
        # Normalizing divides by (h, w); (1, 1) leaves pixel coordinates unchanged
        h, w = self.orig_shape if normalize else (1, 1)
        for i, row in enumerate(data):  # xyxy, track_id if tracking, conf, class_id
            box = {"x1": row[0] / w, "y1": row[1] / h, "x2": row[2] / w, "y2": row[3] / h}
            conf = row[-2]
            class_id = int(row[-1])
            name = self.names[class_id]
            result = {"name": name, "class": class_id, "confidence": conf, "box": box}
            if self.boxes.is_track:
                result["track_id"] = int(row[-3])  # track ID
            if self.masks:
                x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1]  # numpy array
                result["segments"] = {"x": (x / w).tolist(), "y": (y / h).tolist()}
            if self.keypoints is not None:
                x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1)  # torch Tensor
                result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
            results.append(result)

        # Convert detections to JSON
        return json.dumps(results, indent=2)
376
+
377
+
378
class Boxes(BaseTensor):
    """
    A class for storing and manipulating detection boxes.

    Args:
        boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
            with shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values.
            If present, the third last column contains track IDs.
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format.
        conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
        cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
        id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
        xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format.
        xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size.
        xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size.
        data (torch.Tensor): The raw bboxes tensor (alias for `boxes`).

    Methods:
        cpu(): Move the object to CPU memory.
        numpy(): Convert the object to a numpy array.
        cuda(): Move the object to CUDA memory.
        to(*args, **kwargs): Move the object to the specified device.
    """

    def __init__(self, boxes, orig_shape) -> None:
        """Initialize the Boxes class."""
        if boxes.ndim == 1:
            boxes = boxes[None, :]
        n = boxes.shape[-1]
        assert n in (6, 7), f"expected 6 or 7 values but got {n}"  # xyxy, track_id, conf, cls
        super().__init__(boxes, orig_shape)
        self.is_track = n == 7
        self.orig_shape = orig_shape

    @property
    def xyxy(self):
        """Return the boxes in xyxy format."""
        return self.data[:, :4]

    @property
    def conf(self):
        """Return the confidence values of the boxes."""
        # Indexed from the end so layouts with/without the track-id column both work
        return self.data[:, -2]

    @property
    def cls(self):
        """Return the class values of the boxes."""
        return self.data[:, -1]

    @property
    def id(self):
        """Return the track IDs of the boxes (if available)."""
        return self.data[:, -3] if self.is_track else None

    @property
    @lru_cache(maxsize=2)  # maxsize 1 should suffice
    def xywh(self):
        """Return the boxes in xywh format."""
        return ops.xyxy2xywh(self.xyxy)

    @property
    @lru_cache(maxsize=2)
    def xyxyn(self):
        """Return the boxes in xyxy format normalized by original image size."""
        # Copy first so the cached normalized view never mutates self.data
        xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)
        xyxy[..., [0, 2]] /= self.orig_shape[1]
        xyxy[..., [1, 3]] /= self.orig_shape[0]
        return xyxy

    @property
    @lru_cache(maxsize=2)
    def xywhn(self):
        """Return the boxes in xywh format normalized by original image size."""
        xywh = ops.xyxy2xywh(self.xyxy)
        xywh[..., [0, 2]] /= self.orig_shape[1]
        xywh[..., [1, 3]] /= self.orig_shape[0]
        return xywh
458
+
459
+
460
class Masks(BaseTensor):
    """
    A class for storing and manipulating detection masks.

    Attributes:
        xy (list): A list of segments in pixel coordinates.
        xyn (list): A list of normalized segments.

    Methods:
        cpu(): Returns the masks tensor on CPU memory.
        numpy(): Returns the masks tensor as a numpy array.
        cuda(): Returns the masks tensor on GPU memory.
        to(device, dtype): Returns the masks tensor with the specified device and dtype.
    """

    def __init__(self, masks, orig_shape) -> None:
        """Initialize the Masks class with the given masks tensor and original image shape."""
        if masks.ndim == 2:
            masks = masks[None, :]  # promote a single (h, w) mask to (1, h, w)
        super().__init__(masks, orig_shape)

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """Segments scaled to the original image and normalized to [0, 1]."""
        mask_shape = self.data.shape[1:]
        segments = ops.masks2segments(self.data)
        return [ops.scale_coords(mask_shape, segment, self.orig_shape, normalize=True) for segment in segments]

    @property
    @lru_cache(maxsize=1)
    def xy(self):
        """Segments scaled to the original image, in pixel coordinates."""
        mask_shape = self.data.shape[1:]
        segments = ops.masks2segments(self.data)
        return [ops.scale_coords(mask_shape, segment, self.orig_shape, normalize=False) for segment in segments]
498
+
499
+
500
class Keypoints(BaseTensor):
    """
    A class for storing and manipulating detection keypoints.

    Attributes:
        xy (torch.Tensor): A collection of keypoints containing x, y coordinates for each detection.
        xyn (torch.Tensor): A normalized version of xy with coordinates in the range [0, 1].
        conf (torch.Tensor): Confidence values associated with keypoints if available, otherwise None.

    Methods:
        cpu(): Returns a copy of the keypoints tensor on CPU memory.
        numpy(): Returns a copy of the keypoints tensor as a numpy array.
        cuda(): Returns a copy of the keypoints tensor on GPU memory.
        to(device, dtype): Returns a copy of the keypoints tensor with the specified device and dtype.
    """

    @smart_inference_mode()  # avoid keypoints < conf in-place error
    def __init__(self, keypoints, orig_shape) -> None:
        """Initializes the Keypoints object with detection keypoints and original image size."""
        if keypoints.ndim == 2:
            keypoints = keypoints[None, :]
        if keypoints.shape[2] == 3:  # x, y, conf
            # NOTE(review): this zeroing mutates the caller's tensor in place — confirm intended
            mask = keypoints[..., 2] < 0.5  # points with conf < 0.5 (not visible)
            keypoints[..., :2][mask] = 0
        super().__init__(keypoints, orig_shape)
        self.has_visible = self.data.shape[-1] == 3

    @property
    @lru_cache(maxsize=1)
    def xy(self):
        """Returns x, y coordinates of keypoints."""
        return self.data[..., :2]

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """Returns normalized x, y coordinates of keypoints."""
        # Copy first so the cached normalized view never mutates self.data
        xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
        xy[..., 0] /= self.orig_shape[1]
        xy[..., 1] /= self.orig_shape[0]
        return xy

    @property
    @lru_cache(maxsize=1)
    def conf(self):
        """Returns confidence values of keypoints if available, else None."""
        return self.data[..., 2] if self.has_visible else None
547
+
548
+
549
class Probs(BaseTensor):
    """
    A class for storing and manipulating classification predictions.

    Attributes:
        top1 (int): Index of the top 1 class.
        top5 (list[int]): Indices of the top 5 classes.
        top1conf (torch.Tensor): Confidence of the top 1 class.
        top5conf (torch.Tensor): Confidences of the top 5 classes.

    Methods:
        cpu(): Returns a copy of the probs tensor on CPU memory.
        numpy(): Returns a copy of the probs tensor as a numpy array.
        cuda(): Returns a copy of the probs tensor on GPU memory.
        to(): Returns a copy of the probs tensor with the specified device and dtype.
    """

    def __init__(self, probs, orig_shape=None) -> None:
        """Initialize the Probs class with classification probabilities and optional original shape of the image."""
        # orig_shape is unused for classification but kept for BaseTensor API parity
        super().__init__(probs, orig_shape)

    @property
    @lru_cache(maxsize=1)
    def top1(self):
        """Return the index of top 1."""
        return int(self.data.argmax())

    @property
    @lru_cache(maxsize=1)
    def top5(self):
        """Return the indices of top 5."""
        return (-self.data).argsort(0)[:5].tolist()  # this way works with both torch and numpy.

    @property
    @lru_cache(maxsize=1)
    def top1conf(self):
        """Return the confidence of top 1."""
        return self.data[self.top1]

    @property
    @lru_cache(maxsize=1)
    def top5conf(self):
        """Return the confidences of top 5."""
        return self.data[self.top5]
593
+
594
+
595
class OBB(BaseTensor):
    """
    A class for storing and manipulating Oriented Bounding Boxes (OBB).

    Args:
        boxes (torch.Tensor | numpy.ndarray): Detection boxes with shape (num_boxes, 7) or (num_boxes, 8).
            Columns are [x_center, y_center, width, height, rotation, (track_id), conf, cls]; the track-id
            column is only present in the 8-column layout.
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        xywhr (torch.Tensor | numpy.ndarray): The boxes in [x_center, y_center, width, height, rotation] format.
        conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
        cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
        id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
        xyxyxyxy (torch.Tensor | numpy.ndarray): The rotated boxes as 4 corner points, shape (N, 4, 2).
        xyxyxyxyn (torch.Tensor | numpy.ndarray): The corner points normalized by original image size.
        xyxy (torch.Tensor | numpy.ndarray): The axis-aligned bounding boxes of the rotated boxes, shape (N, 4).
        data (torch.Tensor): The raw OBB tensor (alias for `boxes`).

    Methods:
        cpu(): Move the object to CPU memory.
        numpy(): Convert the object to a numpy array.
        cuda(): Move the object to CUDA memory.
        to(*args, **kwargs): Move the object to the specified device.
    """

    def __init__(self, boxes, orig_shape) -> None:
        """Initialize the OBB class with (N, 7) or (N, 8) boxes and the original image shape (height, width)."""
        if boxes.ndim == 1:
            boxes = boxes[None, :]
        n = boxes.shape[-1]
        assert n in (7, 8), f"expected 7 or 8 values but got {n}"  # xywh, rotation, track_id, conf, cls
        super().__init__(boxes, orig_shape)
        self.is_track = n == 8  # an extra column means per-box track IDs are present
        self.orig_shape = orig_shape

    @property
    def xywhr(self):
        """Return the rotated boxes in [x_center, y_center, width, height, rotation] format, (N, 5)."""
        return self.data[:, :5]

    @property
    def conf(self):
        """Return the confidence values of the boxes."""
        return self.data[:, -2]

    @property
    def cls(self):
        """Return the class values of the boxes."""
        return self.data[:, -1]

    @property
    def id(self):
        """Return the track IDs of the boxes, or None when tracking is not active."""
        return self.data[:, -3] if self.is_track else None

    @property
    @lru_cache(maxsize=2)
    def xyxyxyxy(self):
        """Return the boxes as 4 corner points in xyxyxyxy format, (N, 4, 2)."""
        return ops.xywhr2xyxyxyxy(self.xywhr)

    @property
    @lru_cache(maxsize=2)
    def xyxyxyxyn(self):
        """Return the corner points normalized by the original image size, (N, 4, 2)."""
        xyxyxyxyn = self.xyxyxyxy.clone() if isinstance(self.xyxyxyxy, torch.Tensor) else np.copy(self.xyxyxyxy)
        xyxyxyxyn[..., 0] /= self.orig_shape[1]  # x / width
        xyxyxyxyn[..., 1] /= self.orig_shape[0]  # y / height (fix: was incorrectly divided by width)
        return xyxyxyxyn

    @property
    @lru_cache(maxsize=2)
    def xyxy(self):
        """
        Return the axis-aligned bounding boxes of the rotated boxes in xyxy format, (N, 4).

        Accepts both torch and numpy boxes.
        """
        x = self.xyxyxyxy[..., 0]
        y = self.xyxyxyxy[..., 1]
        # Fix: numpy's ndarray.min(axis)/max(axis) return plain arrays with no `.values`
        # attribute, so the torch and numpy paths must reduce differently.
        if isinstance(x, np.ndarray):
            return np.stack([x.min(1), y.min(1), x.max(1), y.max(1)], axis=-1)
        return torch.stack([x.amin(1), y.amin(1), x.amax(1), y.amax(1)], dim=-1)
yolov8_model/ultralytics/engine/trainer.py ADDED
@@ -0,0 +1,755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Train a model on a dataset.
4
+
5
+ Usage:
6
+ $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
7
+ """
8
+
9
+ import math
10
+ import os
11
+ import subprocess
12
+ import time
13
+ import warnings
14
+ from copy import deepcopy
15
+ from datetime import datetime, timedelta
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import torch
20
+ from torch import distributed as dist
21
+ from torch import nn, optim
22
+
23
+ from yolov8_model.ultralytics.cfg import get_cfg, get_save_dir
24
+ from yolov8_model.ultralytics.data.utils import check_cls_dataset, check_det_dataset
25
+ from yolov8_model.ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
26
+ from yolov8_model.ultralytics.utils import (
27
+ DEFAULT_CFG,
28
+ LOGGER,
29
+ RANK,
30
+ TQDM,
31
+ __version__,
32
+ callbacks,
33
+ clean_url,
34
+ colorstr,
35
+ emojis,
36
+ yaml_save,
37
+ )
38
+ from yolov8_model.ultralytics.utils.autobatch import check_train_batch_size
39
+ from yolov8_model.ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
40
+ from yolov8_model.ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
41
+ from yolov8_model.ultralytics.utils.files import get_latest_run
42
+ from yolov8_model.ultralytics.utils.torch_utils import (
43
+ EarlyStopping,
44
+ ModelEMA,
45
+ de_parallel,
46
+ init_seeds,
47
+ one_cycle,
48
+ select_device,
49
+ strip_optimizer,
50
+ )
51
+ from yolov8_model.ultralytics.nn.extra_modules.kernel_warehouse import get_temperature
52
+
53
+
54
+ class BaseTrainer:
55
+ """
56
+ BaseTrainer.
57
+
58
+ A base class for creating trainers.
59
+
60
+ Attributes:
61
+ args (SimpleNamespace): Configuration for the trainer.
62
+ validator (BaseValidator): Validator instance.
63
+ model (nn.Module): Model instance.
64
+ callbacks (defaultdict): Dictionary of callbacks.
65
+ save_dir (Path): Directory to save results.
66
+ wdir (Path): Directory to save weights.
67
+ last (Path): Path to the last checkpoint.
68
+ best (Path): Path to the best checkpoint.
69
+ save_period (int): Save checkpoint every x epochs (disabled if < 1).
70
+ batch_size (int): Batch size for training.
71
+ epochs (int): Number of epochs to train for.
72
+ start_epoch (int): Starting epoch for training.
73
+ device (torch.device): Device to use for training.
74
+ amp (bool): Flag to enable AMP (Automatic Mixed Precision).
75
+ scaler (amp.GradScaler): Gradient scaler for AMP.
76
+ data (str): Path to data.
77
+ trainset (torch.utils.data.Dataset): Training dataset.
78
+ testset (torch.utils.data.Dataset): Testing dataset.
79
+ ema (nn.Module): EMA (Exponential Moving Average) of the model.
80
+ resume (bool): Resume training from a checkpoint.
81
+ lf (nn.Module): Loss function.
82
+ scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
83
+ best_fitness (float): The best fitness value achieved.
84
+ fitness (float): Current fitness value.
85
+ loss (float): Current loss value.
86
+ tloss (float): Total loss value.
87
+ loss_names (list): List of loss names.
88
+ csv (Path): Path to results CSV file.
89
+ """
90
+
91
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the BaseTrainer class.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (dict, optional): Pre-built callback mapping; defaults to the library's standard callbacks.
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)  # may replace self.args entirely when resuming from a checkpoint
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in (-1, 0):  # only the main process creates directories and persists run args
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in ("cpu", "mps"):
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
    try:
        if self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data)
        elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"):
            self.data = check_det_dataset(self.args.data)
            if "yaml_file" in self.data:
                self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

    self.trainset, self.testset = self.get_dataset(self.data)
    self.ema = None

    # Optimization utils init (populated later by _setup_scheduler / _setup_train)
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in (-1, 0):
        callbacks.add_integration_callbacks(self)
160
+
161
def add_callback(self, event: str, callback):
    """Register an additional handler to run for `event`, keeping any existing handlers."""
    handlers = self.callbacks[event]
    handlers.append(callback)
164
+
165
def set_callback(self, event: str, callback):
    """Replace all handlers for `event` with the single given callback."""
    replacement = [callback]
    self.callbacks[event] = replacement
168
+
169
def run_callbacks(self, event: str):
    """Invoke every handler registered for `event`, passing the trainer itself as the sole argument."""
    for handler in self.callbacks.get(event, []):
        handler(self)
173
+
174
def train(self):
    """Entry point: resolve the world size from the device spec and run training (spawning a DDP subprocess when needed)."""
    # Allow device='', device=None on Multi-GPU systems to default to device=0
    if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(self.args.device.split(","))
    elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(self.args.device)
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # i.e. device='cpu' or 'mps'
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch == -1:
            LOGGER.warning(
                "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
                "default 'batch=16'"
            )
            self.args.batch = 16

        # Command
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
            subprocess.run(cmd, check=True)
        except Exception as e:
            raise e
        finally:
            ddp_cleanup(self, str(file))  # always remove the temporary DDP launch file

    else:
        self._do_train(world_size)
210
+
211
+ def _setup_scheduler(self):
212
+ """Initialize training learning rate scheduler."""
213
+ if self.args.cos_lr:
214
+ self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
215
+ else:
216
+ self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
217
+ self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
218
+
219
def _setup_ddp(self, world_size):
    """Initializes and sets the DistributedDataParallel parameters for training (per-rank device + process group)."""
    torch.cuda.set_device(RANK)
    self.device = torch.device("cuda", RANK)
    # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
    os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
    dist.init_process_group(
        "nccl" if dist.is_nccl_available() else "gloo",  # fall back to gloo when NCCL is unavailable
        timeout=timedelta(seconds=10800),  # 3 hours
        rank=RANK,
        world_size=world_size,
    )
231
+
232
def _setup_train(self, world_size):
    """Builds model, dataloaders, optimizer, and scheduler on the correct rank process before training starts."""

    # Model
    self.run_callbacks("on_pretrain_routine_start")
    ckpt = self.setup_model()  # returns the loaded checkpoint dict (or None) for resume_training below
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Freeze layers: accept a list of indices, an int count, or nothing
    freeze_list = (
        self.args.freeze
        if isinstance(self.args.freeze, list)
        else range(self.args.freeze)
        if isinstance(self.args.freeze, int)
        else []
    )
    always_freeze_names = [".dfl"]  # always freeze these layers
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    for k, v in self.model.named_parameters():
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        elif not v.requires_grad:
            LOGGER.info(
                f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                "See ultralytics.engine.trainer for customization of frozen layers."
            )
            v.requires_grad = True

    # Check AMP
    self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
    if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup  # restore callbacks
    if RANK > -1 and world_size > 1:  # DDP
        dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
    self.amp = bool(self.amp)  # as boolean
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
    if world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

    # Check imgsz
    gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
    self.stride = gs  # for multi-scale training

    # Batch size
    if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
        self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)

    # Dataloaders: each rank gets an equal share of the global batch
    batch_size = self.batch_size // max(world_size, 1)
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
    if RANK in (-1, 0):
        # NOTE: When training DOTA dataset, double batch size could get OOM cause some images got more than 2000 objects.
        self.test_loader = self.get_dataloader(
            self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
        )
        self.validator = self.get_validator()
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
        self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
        self.ema = ModelEMA(self.model)
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )
    # Scheduler
    self._setup_scheduler()
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)
    self.scheduler.last_epoch = self.start_epoch - 1  # do not move
    self.run_callbacks("on_pretrain_routine_end")
318
+
319
def _do_train(self, world_size=1):
    """Run the full training loop: warmup, AMP forward/backward, optimizer stepping, validation, and checkpointing."""
    if world_size > 1:
        self._setup_ddp(world_size)
    self._setup_train(world_size)

    nb = len(self.train_loader)  # number of batches
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
    last_opt_step = -1
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
        f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.epochs  # predefine for resume fully trained model edge cases
    for epoch in range(self.start_epoch, self.epochs):
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        self.model.train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(self.train_loader)
        # Update dataloader attributes (optional)
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in (-1, 0):
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)
        self.tloss = None
        self.optimizer.zero_grad()
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")
            # Warmup: linearly ramp lr (and momentum) over the first nw iterations
            ni = i + nb * epoch
            if ni <= nw:
                xi = [0, nw]  # x interp
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # Kernel-warehouse temperature schedule (only for models exposing net_update_temperature)
            if hasattr(self.model, 'net_update_temperature'):
                temp = get_temperature(i + 1, epoch, len(self.train_loader), temp_epoch=20, temp_init_value=1.0)
                self.model.net_update_temperature(temp)

            # Forward
            with torch.cuda.amp.autocast(self.amp):
                batch = self.preprocess_batch(batch)
                self.loss, self.loss_items = self.model(batch)
                if RANK != -1:
                    self.loss *= world_size  # DDP gradient averaging compensation
                self.tloss = (
                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                )

            # Backward
            self.scaler.scale(self.loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= self.accumulate:
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # if DDP training
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
            loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
            losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
            if RANK in (-1, 0):
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                    % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
        self.run_callbacks("on_train_epoch_end")
        if RANK in (-1, 0):
            final_epoch = epoch + 1 == self.epochs
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

            # Validation
            if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
                self.metrics, self.fitness = self.validate()
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness)
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
            if self.args.time:
                # Re-estimate total epochs from elapsed time and rebuild the scheduler to match
                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
                self._setup_scheduler()
                self.scheduler.last_epoch = self.epoch  # do not move
                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
            self.scheduler.step()
        self.run_callbacks("on_fit_epoch_end")
        torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

        # Early Stopping
        if RANK != -1:  # if DDP training
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break all DDP ranks

    if RANK in (-1, 0):
        # Do final val with best.pt
        LOGGER.info(
            f"\n{epoch - self.start_epoch + 1} epochs completed in "
            f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
        )
        self.final_eval()
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    torch.cuda.empty_cache()
    self.run_callbacks("teardown")
476
+
477
def save_model(self):
    """Save model training checkpoints (last/best/periodic) with additional metadata."""
    import pandas as pd  # scope for faster startup

    metrics = {**self.metrics, **{"fitness": self.fitness}}
    results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
    ckpt = {
        "epoch": self.epoch,
        "best_fitness": self.best_fitness,
        "model": deepcopy(de_parallel(self.model)).half(),  # half precision shrinks the checkpoint
        "ema": deepcopy(self.ema.ema).half(),
        "updates": self.ema.updates,
        "optimizer": self.optimizer.state_dict(),
        "train_args": vars(self.args),  # save as dict
        "train_metrics": metrics,
        "train_results": results,
        "date": datetime.now().isoformat(),
        "version": __version__,
    }

    # Save last and best
    torch.save(ckpt, self.last)
    if self.best_fitness == self.fitness:  # current epoch is the best so far
        torch.save(ckpt, self.best)
    if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
        torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
503
+
504
@staticmethod
def get_dataset(data):
    """
    Extract the train and validation split paths from a dataset dict.

    Falls back to the 'test' split when no 'val' split is defined; the second
    element is None when the dict defines neither.
    """
    val = data.get("val") or data.get("test")
    return data["train"], val
512
+
513
def setup_model(self):
    """Load/create/download model for any task; returns the checkpoint dict when loading from a .pt file."""
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    model, weights = self.model, None
    ckpt = None
    if str(model).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(model)
        cfg = ckpt["model"].yaml  # reuse the architecture yaml stored inside the checkpoint
    else:
        cfg = model  # a yaml path / config was passed directly
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return ckpt
527
+
528
def optimizer_step(self):
    """Perform a single optimizer step with AMP gradient unscaling, clipping, and an EMA update (order matters)."""
    self.scaler.unscale_(self.optimizer)  # unscale gradients before clipping so the norm is measured correctly
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.optimizer.zero_grad()
    if self.ema:
        self.ema.update(self.model)
537
+
538
def preprocess_batch(self, batch):
    """Hook for task-specific preprocessing of model inputs and ground truths; the base implementation is a no-op."""
    return batch
541
+
542
def validate(self):
    """
    Run validation on the test set using self.validator.

    Returns:
        (dict, float): Validation metrics (with the "fitness" key removed) and the fitness value.
        When the validator reports no "fitness" key, the negated training loss is used instead.
    """
    metrics = self.validator(self)
    fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
    # Fix: compare against None explicitly — `not self.best_fitness` treated a legitimate
    # best fitness of 0.0 as missing, letting a worse (negative) fitness overwrite it.
    if self.best_fitness is None or self.best_fitness < fitness:
        self.best_fitness = fitness
    return metrics, fitness
553
+
554
def get_model(self, cfg=None, weights=None, verbose=True):
    """Stub for subclasses; building a model from a cfg file is task-specific."""
    message = "This task trainer doesn't support loading cfg files"
    raise NotImplementedError(message)
557
+
558
def get_validator(self):
    """Stub for subclasses; each task trainer supplies its own validator."""
    message = "get_validator function not implemented in trainer"
    raise NotImplementedError(message)
561
+
562
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Stub for subclasses; each task trainer builds its own torch dataloader."""
    message = "get_dataloader function not implemented in trainer"
    raise NotImplementedError(message)
565
+
566
def build_dataset(self, img_path, mode="train", batch=None):
    """Stub for subclasses; each task trainer constructs its own dataset."""
    message = "build_dataset function not implemented in trainer"
    raise NotImplementedError(message)
569
+
570
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Return a loss dict with labelled training loss items tensor.

    Note:
        This is not needed for classification but necessary for segmentation & detection.
    """
    if loss_items is None:
        return ["loss"]
    return {"loss": loss_items}
578
+
579
def set_model_attributes(self):
    """Attach the dataset's class-name mapping to the model before training."""
    names = self.data["names"]
    self.model.names = names
582
+
583
def build_targets(self, preds, targets):
    """No-op hook; subclasses may build target tensors for YOLO training here."""
    return None
586
+
587
def progress_string(self):
    """Return a string describing training progress; the base implementation reports nothing."""
    return ""
590
+
591
# TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
    """No-op hook for plotting training samples; overridden by task trainers."""
    return None
595
+
596
def plot_training_labels(self):
    """No-op hook for plotting training labels; overridden by task trainers."""
    return None
599
+
600
def save_metrics(self, metrics):
    """Append the current epoch's metrics as one row of the results CSV, writing a header on first use."""
    names = list(metrics.keys())
    values = list(metrics.values())
    if self.csv.exists():
        header = ""
    else:
        # Right-justified 23-char columns to match the row format below.
        header = ",".join(f"{col:>23}" for col in ["epoch"] + names) + "\n"
    row = ",".join(f"{v:>23.5g}" for v in [self.epoch + 1] + values) + "\n"
    with open(self.csv, "a") as f:
        f.write(header + row)
607
+
608
def plot_metrics(self):
    """No-op hook; subclasses plot metric curves here."""
    return None
611
+
612
def on_plot(self, name, data=None):
    """Record a plot (keyed by its Path) with a timestamp so callbacks can consume it later."""
    key = Path(name)
    self.plots[key] = {"data": data, "timestamp": time.time()}
616
+
617
def final_eval(self):
    """Strip optimizer state from the saved checkpoints and run a final validation on best.pt."""
    for f in self.last, self.best:
        if f.exists():
            strip_optimizer(f)  # strip optimizers
            if f is self.best:
                LOGGER.info(f"\nValidating {f}...")
                self.validator.args.plots = self.args.plots
                self.metrics = self.validator(model=f)
                self.metrics.pop("fitness", None)
                self.run_callbacks("on_fit_epoch_end")
628
+
629
def check_resume(self, overrides):
    """
    Check if a resume checkpoint exists and replace the run arguments with the checkpoint's saved args.

    Args:
        overrides (dict): User-supplied arguments for this run; only 'imgsz' and 'batch' are
            carried over on top of the checkpoint args (e.g. to reduce memory after a CUDA OOM crash).

    Raises:
        FileNotFoundError: If resume was requested but no valid checkpoint could be located or loaded.
    """
    resume = self.args.resume
    if resume:
        try:
            # Resolve the checkpoint: use the given path if it exists, else the most recent run
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = attempt_load_weights(last).args
            if not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            # Checkpoint args fully replace current args so training continues with the original config
            self.args = get_cfg(ckpt_args)
            self.args.model = str(last)  # reinstate model
            for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
                if k in overrides:
                    setattr(self.args, k, overrides[k])

        except Exception as e:
            # Any failure (bad path, unreadable weights, ...) is surfaced uniformly as a resume error
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume
655
+
656
def resume_training(self, ckpt):
    """
    Resume YOLO training from a checkpoint's epoch and best fitness.

    Restores optimizer and EMA state from `ckpt`, validates that there is actually something left
    to resume, and extends `self.epochs` when the checkpoint was already fully trained (fine-tuning).

    Args:
        ckpt (dict | None): Loaded checkpoint dictionary; a no-op if None.
    """
    if ckpt is None:
        return
    best_fitness = 0.0
    start_epoch = ckpt["epoch"] + 1  # checkpoint stores the last completed epoch
    if ckpt["optimizer"] is not None:
        self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
        best_fitness = ckpt["best_fitness"]
    if self.ema and ckpt.get("ema"):
        self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
        self.ema.updates = ckpt["updates"]
    if self.resume:
        # NOTE(review): ckpt["epoch"] is presumably set to -1 when training has finished,
        # making start_epoch == 0 — confirm against the checkpoint-saving code.
        assert start_epoch > 0, (
            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
        )
        LOGGER.info(
            f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
        )
    if self.epochs < start_epoch:
        # Requested epoch count is already exceeded: interpret self.epochs as extra fine-tune epochs
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs
    self.best_fitness = best_fitness
    self.start_epoch = start_epoch
    if start_epoch > (self.epochs - self.args.close_mosaic):
        # Already inside the final no-mosaic window, so disable mosaic immediately
        self._close_dataloader_mosaic()
685
+
686
+ def _close_dataloader_mosaic(self):
687
+ """Update dataloaders to stop using mosaic augmentation."""
688
+ if hasattr(self.train_loader.dataset, "mosaic"):
689
+ self.train_loader.dataset.mosaic = False
690
+ if hasattr(self.train_loader.dataset, "close_mosaic"):
691
+ LOGGER.info("Closing dataloader mosaic")
692
+ self.train_loader.dataset.close_mosaic(hyp=self.args)
693
+
694
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
    weight decay, and number of iterations.

    Parameters are split into three groups: weights with decay, normalization-layer weights
    without decay, and biases without decay — applying weight decay to norm weights and biases
    is the reason for the split.

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations. Default: 'auto'.
        lr (float, optional): The learning rate for the optimizer. Default: 0.001.
        momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
        decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'. Default: 1e5.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.

    Raises:
        NotImplementedError: If `name` is not one of the supported optimizer names.
    """

    g = [], [], []  # optimizer parameter groups: g[0] decayed weights, g[1] norm weights, g[2] biases
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = getattr(model, "nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        # Long schedules use SGD; short ones use AdamW with a class-count-scaled lr
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    # recurse=False so each parameter is classified by its immediately-owning module
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn):  # weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    # Optimizer is created with the bias group; the other groups are appended below
    if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers "
            f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
            "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
    )
    return optimizer
yolov8_model/ultralytics/engine/tuner.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
4
+ instance segmentation, image classification, pose estimation, and multi-object tracking.
5
+
6
+ Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
7
+ that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
8
+ where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
9
+
10
+ Example:
11
+ Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
12
+ ```python
13
+ from ultralytics import YOLO
14
+
15
+ model = YOLO('yolov8n.pt')
16
+ model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
17
+ ```
18
+ """
19
+ import random
20
+ import shutil
21
+ import subprocess
22
+ import time
23
+
24
+ import numpy as np
25
+ import torch
26
+
27
+ from yolov8_model.ultralytics.cfg import get_cfg, get_save_dir
28
+ from yolov8_model.ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
29
+ from yolov8_model.ultralytics.utils.plotting import plot_tune_results
30
+
31
+
32
class Tuner:
    """
    Class responsible for hyperparameter tuning of YOLO models.

    The class evolves YOLO model hyperparameters over a given number of iterations
    by mutating them according to the search space and retraining the model to evaluate their performance.

    Attributes:
        space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
        tune_dir (Path): Directory where evolution logs and results will be saved.
        tune_csv (Path): Path to the CSV file where evolution logs are saved.

    Methods:
        _mutate(hyp: dict) -> dict:
            Mutates the given hyperparameters within the bounds specified in `self.space`.

        __call__():
            Executes the hyperparameter evolution across multiple iterations.

    Example:
        Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
        ```python
        from ultralytics import YOLO

        model = YOLO('yolov8n.pt')
        model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
        ```

        Tune with custom search space.
        ```python
        from ultralytics import YOLO

        model = YOLO('yolov8n.pt')
        model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
        ```
    """

    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
        """
        Initialize the Tuner with configurations.

        Args:
            args (dict, optional): Configuration for hyperparameter evolution.
            _callbacks (list, optional): Callback functions to register; defaults to the standard set.
        """
        # Each entry is (min, max) or (min, max, gain), where gain scales the mutation magnitude
        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
            "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
            "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
            "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
            "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
            "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
            "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
            "box": (1.0, 20.0),  # box loss gain
            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
            "dfl": (0.4, 6.0),  # dfl loss gain
            "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
            "degrees": (0.0, 45.0),  # image rotation (+/- deg)
            "translate": (0.0, 0.9),  # image translation (+/- fraction)
            "scale": (0.0, 0.95),  # image scale (+/- gain)
            "shear": (0.0, 10.0),  # image shear (+/- deg)
            "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            "flipud": (0.0, 1.0),  # image flip up-down (probability)
            "fliplr": (0.0, 1.0),  # image flip left-right (probability)
            "mosaic": (0.0, 1.0),  # image mosaic (probability)
            "mixup": (0.0, 1.0),  # image mixup (probability)
            "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
        }
        self.args = get_cfg(overrides=args)
        self.tune_dir = get_save_dir(self.args, name="tune")
        self.tune_csv = self.tune_dir / "tune_results.csv"
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.prefix = colorstr("Tuner: ")
        callbacks.add_integration_callbacks(self)
        LOGGER.info(
            f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
        )

    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
        """
        Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.

        Args:
            parent (str): Parent selection method: 'single' or 'weighted'.
            n (int): Number of parents to consider.
            mutation (float): Probability of a parameter mutation in any given iteration.
            sigma (float): Standard deviation for Gaussian random number generator.

        Returns:
            (dict): A dictionary containing mutated hyperparameters.
        """
        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
            # Select parent(s) — CSV rows are [fitness, hyp1, hyp2, ...]
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            n = min(n, len(x))  # number of previous results to consider
            x = x[np.argsort(-fitness)][:n]  # top n mutations
            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
            if parent == "single" or len(x) == 1:
                # x = x[random.randint(0, n - 1)]  # random selection
                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
            elif parent == "weighted":
                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

            # Mutate: multiply each hyp by a random factor clipped to [0.3, 3.0]
            r = np.random  # method
            r.seed(int(time.time()))
            g = np.array([v[2] if len(v) == 3 else 1.0 for k, v in self.space.items()])  # gains 0-1
            ng = len(self.space)
            v = np.ones(ng)
            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}  # +1 skips fitness column
        else:
            # First iteration: start from the current configuration
            hyp = {k: getattr(self.args, k) for k in self.space.keys()}

        # Constrain to limits
        for k, v in self.space.items():
            hyp[k] = max(hyp[k], v[0])  # lower limit
            hyp[k] = min(hyp[k], v[1])  # upper limit
            hyp[k] = round(hyp[k], 5)  # significant digits

        return hyp

    def __call__(self, model=None, iterations=10, cleanup=True):
        """
        Executes the hyperparameter evolution process when the Tuner instance is called.

        This method iterates through the number of iterations, performing the following steps in each iteration:
        1. Load the existing hyperparameters or initialize new ones.
        2. Mutate the hyperparameters using the `mutate` method.
        3. Train a YOLO model with the mutated hyperparameters.
        4. Log the fitness score and mutated hyperparameters to a CSV file.

        Args:
            model (Model): A pre-initialized YOLO model to be used for training.
            iterations (int): The number of generations to run the evolution for.
            cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.

        Note:
            The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
            Ensure this path is set correctly in the Tuner instance.
        """

        t0 = time.time()
        best_save_dir, best_metrics = None, None
        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
        for i in range(iterations):
            # Mutate hyperparameters
            mutated_hyp = self._mutate()
            LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")

            metrics = {}
            train_args = {**vars(self.args), **mutated_hyp}
            save_dir = get_save_dir(get_cfg(train_args))
            weights_dir = save_dir / "weights"
            # NOTE(review): best.pt/last.pt existence is checked before training runs, so on the
            # first use of a fresh save_dir this resolves to last.pt — confirm this is intended.
            ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
            try:
                # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
                cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
                return_code = subprocess.run(cmd, check=True).returncode
                metrics = torch.load(ckpt_file)["train_metrics"]
                assert return_code == 0, "training failed"

            except Exception as e:
                # Failed iterations are logged with fitness 0.0 rather than aborting the evolution
                LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}")

            # Save results and mutated_hyp to CSV
            fitness = metrics.get("fitness", 0.0)
            log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
            headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
            with open(self.tune_csv, "a") as f:
                f.write(headers + ",".join(map(str, log_row)) + "\n")

            # Get best results
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            best_idx = fitness.argmax()
            best_is_current = best_idx == i
            if best_is_current:
                best_save_dir = save_dir
                best_metrics = {k: round(v, 5) for k, v in metrics.items()}
                for ckpt in weights_dir.glob("*.pt"):
                    shutil.copy2(ckpt, self.tune_dir / "weights")
            elif cleanup:
                shutil.rmtree(ckpt_file.parent)  # remove iteration weights/ dir to reduce storage space

            # Plot tune results
            plot_tune_results(self.tune_csv)

            # Save and print tune results
            header = (
                f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
                f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
                f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
                f'{self.prefix}Best fitness metrics are {best_metrics}\n'
                f'{self.prefix}Best fitness model is {best_save_dir}\n'
                f'{self.prefix}Best fitness hyperparameters are printed below.\n'
            )
            LOGGER.info("\n" + header)
            data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
            yaml_save(
                self.tune_dir / "best_hyperparameters.yaml",
                data=data,
                header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
            )
            yaml_print(self.tune_dir / "best_hyperparameters.yaml")
yolov8_model/ultralytics/engine/validator.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Check a model's accuracy on a test or val split of a dataset.
4
+
5
+ Usage:
6
+ $ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
7
+
8
+ Usage - formats:
9
+ $ yolo mode=val model=yolov8n.pt # PyTorch
10
+ yolov8n.torchscript # TorchScript
11
+ yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
12
+ yolov8n_openvino_model # OpenVINO
13
+ yolov8n.engine # TensorRT
14
+ yolov8n.mlpackage # CoreML (macOS-only)
15
+ yolov8n_saved_model # TensorFlow SavedModel
16
+ yolov8n.pb # TensorFlow GraphDef
17
+ yolov8n.tflite # TensorFlow Lite
18
+ yolov8n_edgetpu.tflite # TensorFlow Edge TPU
19
+ yolov8n_paddle_model # PaddlePaddle
20
+ """
21
+ import json
22
+ import time
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
+ from yolov8_model.ultralytics.cfg import get_cfg, get_save_dir
29
+ from yolov8_model.ultralytics.data.utils import check_cls_dataset, check_det_dataset
30
+ from yolov8_model.ultralytics.nn.autobackend import AutoBackend
31
+ from yolov8_model.ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
32
+ from yolov8_model.ultralytics.utils.checks import check_imgsz
33
+ from yolov8_model.ultralytics.utils.ops import Profile
34
+ from yolov8_model.ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
35
+
36
+
37
class BaseValidator:
    """
    BaseValidator.

    A base class for creating validators.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): Dataloader to use for validation.
        pbar (tqdm): Progress bar to update during validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        names (dict): Class names.
        seen: Records the number of images seen so far during validation.
        stats: Placeholder for statistics during validation.
        confusion_matrix: Placeholder for a confusion matrix.
        nc: Number of classes.
        iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
        jdict (dict): Dictionary to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
                      batch processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initializes a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            pbar (tqdm.tqdm): Progress bar for displaying progress.
            args (SimpleNamespace): Configuration for the validator.
            _callbacks (dict): Dictionary to store various callback functions.
        """
        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        self.pbar = pbar
        # Most state is populated lazily in __call__ / init_metrics, not here
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.001  # default conf=0.001
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
        gets priority).
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            # Validate the in-progress model (EMA if available) on the trainer's device
            self.device = trainer.device
            self.data = trainer.data
            self.args.half = self.device.type != "cpu"  # force FP16 val during training
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            # self.model = model
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            # Only produce plots on the final epoch or when early stopping may trigger
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            # Standalone validation: wrap weights in AutoBackend to support exported formats
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            # self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not pt and not jit:
                self.args.batch = 1  # export.py models default to batch-size 1
                LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in ("cpu", "mps"):
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not pt:
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        # One profiler per pipeline stage: preprocess, inference, loss, postprocess
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference
            with dt[1]:
                preds = model(batch["img"], augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")
        stats = self.get_stats()
        self.check_stats(stats)
        # Convert accumulated per-stage time to milliseconds per image
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info(
                "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
                % tuple(self.speed.values())
            )
            if self.args.save_json and self.jdict:
                with open(str(self.save_dir / "predictions.json"), "w") as f:
                    LOGGER.info(f"Saving {f.name}...")
                    json.dump(self.jdict, f)  # flatten and save
                stats = self.eval_json(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
        """
        Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape(N,).
            true_classes (torch.Tensor): Target class indices of shape(M,).
            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
            use_scipy (bool): Whether to use scipy for matching (more precise).

        Returns:
            (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
        """
        # Dx10 matrix, where D - detections, 10 - IoU thresholds
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
        # LxD matrix where L - labels (rows), D - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # Optimal assignment via Hungarian algorithm instead of the greedy dedup below
                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
                import scipy  # scope import to avoid importing for all commands

                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
                matches = np.array(matches).T  # columns: [label_idx, detection_idx]
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        # Greedy: sort by IoU descending, then keep each detection and label at most once
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        # matches = matches[matches[:, 2].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Runs all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocesses an input batch; identity in the base class, overridden by task validators."""
        return batch

    def postprocess(self, preds):
        """Post-processes raw predictions; identity in the base class, overridden by task validators."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Updates metrics based on predictions and batch."""
        pass

    def finalize_metrics(self, *args, **kwargs):
        """Finalizes and returns all metrics."""
        pass

    def get_stats(self):
        """Returns statistics about the model's performance."""
        return {}

    def check_stats(self, stats):
        """Checks statistics."""
        pass

    def print_results(self):
        """Prints the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Returns the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)"""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
        """Plots validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plots YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """Evaluate and return JSON format of prediction statistics."""
        pass
yolov8_model/ultralytics/hub/__init__.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import requests
4
+
5
+ from yolov8_model.ultralytics.data.utils import HUBDatasetStats
6
+ from yolov8_model.ultralytics.hub.auth import Auth
7
+ from yolov8_model.ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
8
+ from yolov8_model.ultralytics.utils import LOGGER, SETTINGS, checks
9
+
10
+
11
def login(api_key: str = None, save=True) -> bool:
    """
    Log in to the Ultralytics HUB API using the provided API key.

    The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
    environment variable if successfully authenticated.

    Args:
        api_key (str, optional): API key to use for authentication.
            If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
        save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.

    Returns:
        (bool): True if authentication is successful, False otherwise.
    """
    checks.check_requirements("hub-sdk>=0.0.2")
    from hub_sdk import HUBClient

    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # page shown to the user on failure
    saved_key = SETTINGS.get("api_key")
    active_key = api_key or saved_key
    credentials = {"api_key": active_key} if active_key and active_key != "" else None

    client = HUBClient(credentials)  # authenticate with whatever credentials are available

    if not client.authenticated:
        # Authentication failed: point the user at the key-management page
        LOGGER.info(f"{PREFIX}Retrieve API key from {api_key_url}")
        return False

    # Persist the validated key when requested and it differs from what is stored
    if save and client.api_key != saved_key:
        SETTINGS.update({"api_key": client.api_key})

    # A directly supplied key (or no stored credentials) counts as a *new* authentication
    is_new = client.api_key == api_key or not credentials
    LOGGER.info(f"{PREFIX}{'New authentication successful ✅' if is_new else 'Authenticated ✅'}")
    return True
53
+
54
+
55
def logout():
    """
    Log out of Ultralytics HUB by clearing the API key stored in the settings file.

    To log in again, use 'yolo hub login'.

    Example:
        ```python
        from ultralytics import hub

        hub.logout()
        ```
    """
    SETTINGS["api_key"] = ""  # wipe the stored credential
    SETTINGS.save()  # persist the change to disk
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
69
+
70
+
71
def reset_model(model_id=""):
    """Reset a trained HUB model (by id) back to an untrained state."""
    response = requests.post(
        f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key}
    )
    if response.status_code != 200:
        LOGGER.warning(f"{PREFIX}Model reset failure {response.status_code} {response.reason}")
        return
    LOGGER.info(f"{PREFIX}Model reset successfully")
78
+
79
+
80
def export_fmts_hub():
    """Returns a list of HUB-supported export formats."""
    # Use the package-local import path for consistency with every other import in this vendored module;
    # the bare 'ultralytics' package may not be importable in this relocated 'yolov8_model' layout.
    from yolov8_model.ultralytics.engine.exporter import export_formats

    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
85
+
86
+
87
def export_model(model_id="", format="torchscript"):
    """Start a HUB-side export of the given model to the requested format."""
    supported = export_fmts_hub()
    assert format in supported, f"Unsupported export format '{format}', valid formats are {supported}"
    response = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
    )
    assert response.status_code == 200, f"{PREFIX}{format} export failure {response.status_code} {response.reason}"
    LOGGER.info(f"{PREFIX}{format} export started ✅")
95
+
96
+
97
def get_export(model_id="", format="torchscript"):
    """Fetch the exported-model dictionary (including its download URL) for a HUB model."""
    supported = export_fmts_hub()
    assert format in supported, f"Unsupported export format '{format}', valid formats are {supported}"
    response = requests.post(
        f"{HUB_API_ROOT}/get-export",
        json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
        headers={"x-api-key": Auth().api_key},
    )
    assert response.status_code == 200, f"{PREFIX}{format} get_export failure {response.status_code} {response.reason}"
    return response.json()
107
+
108
+
109
def check_dataset(path="", task="detect"):
    """
    Error-check a HUB dataset Zip file before upload, logging where to upload it on success.

    Args:
        path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
        task (str, optional): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Defaults to 'detect'.

    Example:
        ```python
        from ultralytics.hub import check_dataset

        check_dataset('path/to/coco8.zip', task='detect')  # detect dataset
        check_dataset('path/to/coco8-seg.zip', task='segment')  # segment dataset
        check_dataset('path/to/coco8-pose.zip', task='pose')  # pose dataset
        ```
    """
    stats = HUBDatasetStats(path=path, task=task)
    stats.get_json()  # performs the actual validation work
    LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
yolov8_model/ultralytics/hub/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.02 kB). View file
 
yolov8_model/ultralytics/hub/__pycache__/auth.cpython-310.pyc ADDED
Binary file (4.32 kB). View file
 
yolov8_model/ultralytics/hub/__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.53 kB). View file
 
yolov8_model/ultralytics/hub/auth.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import requests
4
+
5
+ from yolov8_model.ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
6
+ from yolov8_model.ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
7
+
8
API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # HUB page where users create/copy their API key
9
+
10
+
11
class Auth:
    """
    Manages authentication processes including API key handling, cookie-based authentication, and header generation.

    The class supports different methods of authentication:
    1. Directly using an API key.
    2. Authenticating using browser cookies (specifically in Google Colab).
    3. Prompting the user to enter an API key.

    Attributes:
        id_token (str or bool): Token used for identity verification, initialized as False.
        api_key (str or bool): API key for authentication, initialized as False.
        model_key (bool): Placeholder for model key, initialized as False.
    """

    # Class-level defaults; instances overwrite api_key/id_token during authentication
    id_token = api_key = model_key = False

    def __init__(self, api_key="", verbose=False):
        """
        Initialize the Auth class with an optional API key.

        Args:
            api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
            verbose (bool, optional): Whether to log authentication progress messages.
        """
        # Split the input API key in case it contains a combined key_model and keep only the API key part
        api_key = api_key.split("_")[0]

        # Set API key attribute as value passed or SETTINGS API key if none passed
        self.api_key = api_key or SETTINGS.get("api_key", "")

        # If an API key is provided
        if self.api_key:
            # If the provided API key matches the API key in the SETTINGS
            if self.api_key == SETTINGS.get("api_key"):
                # Log that the user is already logged in
                if verbose:
                    LOGGER.info(f"{PREFIX}Authenticated ✅")
                return
            else:
                # Attempt to authenticate with the provided API key
                success = self.authenticate()
        # If the API key is not provided and the environment is a Google Colab notebook
        elif is_colab():
            # Attempt to authenticate using browser cookies
            success = self.auth_with_cookies()
        else:
            # Request an API key interactively
            success = self.request_api_key()

        # Update SETTINGS with the new API key after successful authentication
        if success:
            SETTINGS.update({"api_key": self.api_key})
            # Log that the new login was successful
            if verbose:
                LOGGER.info(f"{PREFIX}New authentication successful ✅")
        elif verbose:
            LOGGER.info(f"{PREFIX}Retrieve API key from {API_KEY_URL}")

    def request_api_key(self, max_attempts=3) -> bool:
        """
        Prompt the user to input their API key, retrying up to `max_attempts` times.

        Returns:
            (bool): True if an entered key authenticates successfully.

        Raises:
            ConnectionError: If all attempts fail.
        """
        import getpass

        for attempts in range(max_attempts):
            LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
            input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
            self.api_key = input_key.split("_")[0]  # remove model id if present
            if self.authenticate():
                return True
        raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

    def authenticate(self) -> bool:
        """
        Attempt to authenticate with the server using either id_token or API key.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        try:
            if header := self.get_auth_header():
                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
                if not r.json().get("success", False):
                    raise ConnectionError("Unable to authenticate.")
                return True
            raise ConnectionError("User has not authenticated locally.")
        except ConnectionError:
            self.id_token = self.api_key = False  # reset invalid credentials
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
            return False

    def auth_with_cookies(self) -> bool:
        """
        Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a
        supported browser.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        if not is_colab():
            return False  # Currently only works with Colab
        try:
            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
            if authn.get("success", False):
                self.id_token = authn.get("data", {}).get("idToken", None)
                self.authenticate()  # validate the freshly obtained id_token
                return True
            raise ConnectionError("Unable to fetch browser authentication details.")
        except ConnectionError:
            self.id_token = False  # reset invalid token
            return False

    def get_auth_header(self):
        """
        Get the authentication header for making API requests.

        Returns:
            (dict): The authentication header if id_token or API key is set, None otherwise.
        """
        if self.id_token:
            return {"authorization": f"Bearer {self.id_token}"}
        elif self.api_key:
            return {"x-api-key": self.api_key}
        # else returns None (implicit) — caller treats a falsy header as "not authenticated locally"
yolov8_model/ultralytics/hub/session.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import threading
4
+ import time
5
+ from http import HTTPStatus
6
+ from pathlib import Path
7
+
8
+ import requests
9
+
10
+ from yolov8_model.ultralytics.hub.utils import HUB_WEB_ROOT, HELP_MSG, PREFIX, TQDM
11
+ from yolov8_model.ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
12
+ from yolov8_model.ultralytics.utils.errors import HUBModelError
13
+
14
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"  # agent string reported to HUB, distinguishing Colab from local runs
15
+
16
+
17
class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

    Attributes:
        agent_id (str): Identifier for the instance communicating with the server.
        model_id (str): Identifier for the YOLO model being trained.
        model_url (str): URL for the model in Ultralytics HUB.
        api_url (str): API URL for the model in Ultralytics HUB.
        auth_header (dict): Authentication header for the Ultralytics HUB API requests.
        rate_limits (dict): Rate limits for different API calls (in seconds).
        timers (dict): Timers for rate limiting.
        metrics_queue (dict): Queue for the model's metrics.
        model (dict): Model data fetched from Ultralytics HUB.
        alive (bool): Indicates if the heartbeat loop is active.
    """

    def __init__(self, identifier):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session.
                It can be a URL string or a model key with specific format.

        Raises:
            ValueError: If the provided model identifier is invalid.
            ConnectionError: If connecting with global API key is not supported.
            ModuleNotFoundError: If hub-sdk package is not installed.
        """
        from hub_sdk import HUBClient

        self.rate_limits = {
            "metrics": 3.0,
            "ckpt": 900.0,
            "heartbeat": 300.0,
        }  # rate limits (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until upload
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py

        # Parse input
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials
        active_key = api_key or SETTINGS.get("api_key")
        credentials = {"api_key": active_key} if active_key else None  # set credentials

        # Initialize client
        self.client = HUBClient(credentials)

        if model_id:
            self.load_model(model_id)  # load existing model
        else:
            self.model = self.client.model()  # load empty model

    def load_model(self, model_id):
        """Loads an existing model from Ultralytics HUB using the provided model identifier."""
        self.model = self.client.model(model_id)
        if not self.model.data:  # then model does not exist
            raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        self._set_train_args()

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])
        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def create_model(self, model_args):
        """Initializes a HUB training session with the specified model identifier."""
        payload = {
            "config": {
                "batchSize": model_args.get("batch", -1),
                "epochs": model_args.get("epochs", 300),
                "imageSize": model_args.get("imgsz", 640),
                "patience": model_args.get("patience", 100),
                "device": model_args.get("device", ""),
                "cache": model_args.get("cache", "ram"),
            },
            "dataset": {"name": model_args.get("data")},
            "lineage": {
                "architecture": {
                    "name": self.filename.replace(".pt", "").replace(".yaml", ""),
                },
                "parent": {},
            },
            "meta": {"name": self.filename},
        }

        if self.filename.endswith(".pt"):
            payload["lineage"]["parent"]["name"] = self.filename  # pretrained weights become the lineage parent

        self.model.create_model(payload)

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])

        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def _parse_identifier(self, identifier):
        """
        Parses the given identifier to determine the type of identifier and extract relevant components.

        The method supports different identifier formats:
        - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
        - An identifier containing an API key and a model ID separated by an underscore
        - An identifier that is solely a model ID of a fixed length
        - A local filename that ends with '.pt' or '.yaml'

        Args:
            identifier (str): The identifier string to be parsed.

        Returns:
            (tuple): A tuple containing the API key, model ID, and filename as applicable.

        Raises:
            HUBModelError: If the identifier format is not recognized.
        """

        # Initialize variables
        api_key, model_id, filename = None, None, None

        # Check if identifier is a HUB URL
        if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
            # Extract the model_id after the HUB_WEB_ROOT URL
            model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
        else:
            # Split the identifier based on underscores only if it's not a HUB URL
            parts = identifier.split("_")

            # Check if identifier is in the format of API key and model ID (42-char key, 20-char id)
            if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
                api_key, model_id = parts
            # Check if identifier is a single model ID
            elif len(parts) == 1 and len(parts[0]) == 20:
                model_id = parts[0]
            # Check if identifier is a local filename
            elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
                filename = identifier
            else:
                raise HUBModelError(
                    f"model='{identifier}' could not be parsed. Check format is correct. "
                    f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
                )

        return api_key, model_id, filename

    def _set_train_args(self, **kwargs):
        """Initializes training arguments and creates a model entry on the Ultralytics HUB."""
        if self.model.is_trained():
            # Model is already trained
            raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))

        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
            self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            def get_train_args(config):
                """Build the training-argument dict from the HUB model's saved config."""
                return {
                    "batch": config["batchSize"],
                    "epochs": config["epochs"],
                    "imgsz": config["imageSize"],
                    "patience": config["patience"],
                    "device": config["device"],
                    "cache": config["cache"],
                    "data": self.model.get_dataset_url(),
                }

            self.train_args = get_train_args(self.model.data.get("config"))
            # Set the model file as either a *.pt or *.yaml file
            self.model_file = (
                self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
            )

        if not self.train_args.get("data"):
            raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id

    def request_queue(
        self,
        request_func,
        retry=3,
        timeout=30,
        thread=True,
        verbose=True,
        progress_total=None,
        *args,
        **kwargs,
    ):
        """
        Call `request_func` with exponential-backoff retries, a total timeout, and optional threading.

        Args:
            request_func (callable): Function performing the HTTP request; must return a Response or None.
            retry (int): Number of retries to attempt.
            timeout (int): Total seconds allowed across all attempts.
            thread (bool): If True, run the retry loop in a daemon thread and return None immediately.
            verbose (bool): Whether to log failure messages.
            progress_total (int, optional): If set, display an upload progress bar of this total size.
            *args: Positional arguments forwarded to `request_func`.
            **kwargs: Keyword arguments forwarded to `request_func`.

        Returns:
            (requests.Response | None): Final response when run synchronously, else None when threaded.
        """

        def retry_request():
            """Attempts to call `request_func` with retries and a total timeout."""
            t0 = time.time()  # Record the start time for the timeout
            response = None  # ensure a defined return value even if the timeout fires before the first attempt
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
                    LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                    break  # Timeout reached, exit loop

                response = request_func(*args, **kwargs)
                if response is None:
                    LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                    time.sleep(2**i)  # Exponential backoff before retrying
                    continue  # Skip further processing and retry

                if progress_total:
                    self._show_upload_progress(progress_total, response)

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                    return response  # Success, no need to retry

                if i == 0:
                    # Initial attempt, check status code and provide messages
                    message = self._get_failure_message(response, retry, timeout)

                    if verbose:
                        LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

                if not self._should_retry(response.status_code):
                    # Fixed: message previously had an unbalanced '(' around the status code
                    LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})")
                    break  # Not an error that should be retried, exit loop

                time.sleep(2**i)  # Exponential backoff for retries

            return response

        if thread:
            # Start a new thread to run the retry_request function
            threading.Thread(target=retry_request, daemon=True).start()
        else:
            # If running in the main thread, call retry_request directly
            return retry_request()

    def _should_retry(self, status_code):
        """Determines if a request should be retried based on the HTTP status code."""
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
        return status_code in retry_codes

    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
        """
        Generate a retry message based on the response status code.

        Args:
            response: The HTTP response object.
            retry: The number of retry attempts allowed.
            timeout: The maximum timeout duration.

        Returns:
            (str): The retry message.
        """
        if self._should_retry(response.status_code):
            return f"Retrying {retry}x for {timeout}s." if retry else ""
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
            return (
                f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
                f"Please retry after {headers['Retry-After']}s."
            )
        else:
            try:
                return response.json().get("message", "No JSON message.")
            except AttributeError:
                return "Unable to read JSON."

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""
        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

    def upload_model(
        self,
        epoch: int,
        weights: str,
        is_best: bool = False,
        map: float = 0.0,
        final: bool = False,
    ) -> None:
        """
        Upload a model checkpoint to Ultralytics HUB.

        Args:
            epoch (int): The current training epoch.
            weights (str): Path to the model weights file.
            is_best (bool): Indicates if the current model is the best one so far.
            map (float): Mean average precision of the model.
            final (bool): Indicates if the model is the final model after training.
        """
        if Path(weights).is_file():
            progress_total = Path(weights).stat().st_size if final else None  # Only show progress if final
            self.request_queue(
                self.model.upload_model,
                epoch=epoch,
                weights=weights,
                is_best=is_best,
                map=map,
                final=final,
                retry=10,
                timeout=3600,
                thread=not final,  # final upload is synchronous so training doesn't exit before it completes
                progress_total=progress_total,
            )
        else:
            LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")

    def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
        """
        Display a progress bar to track the upload progress of a file download.

        Args:
            content_length (int): The total size of the content to be downloaded in bytes.
            response (requests.Response): The response object from the file download request.

        Returns:
            None
        """
        with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))
yolov8_model/ultralytics/hub/utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import os
4
+ import platform
5
+ import random
6
+ import sys
7
+ import threading
8
+ import time
9
+ from pathlib import Path
10
+
11
+ import requests
12
+
13
+ from yolov8_model.ultralytics.utils import (
14
+ ENVIRONMENT,
15
+ LOGGER,
16
+ ONLINE,
17
+ RANK,
18
+ SETTINGS,
19
+ TESTS_RUNNING,
20
+ TQDM,
21
+ TryExcept,
22
+ __version__,
23
+ colorstr,
24
+ get_git_origin_url,
25
+ is_colab,
26
+ is_git_dir,
27
+ is_pip_package,
28
+ )
29
+ from yolov8_model.ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
30
+
31
# HUB endpoints; both can be overridden via environment variables (useful for staging/testing)
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")

PREFIX = colorstr("Ultralytics HUB: ")  # prefix for all HUB log messages
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."
36
+
37
+
38
def request_with_credentials(url: str) -> any:  # NOTE(review): '-> any' is the builtin function; likely meant typing.Any — confirm
    """
    Make an AJAX request with cookies attached in a Google Colab environment.

    Args:
        url (str): The URL to make the request to.

    Returns:
        (any): The response data from the AJAX request.

    Raises:
        OSError: If the function is not run in a Google Colab environment.
    """
    if not is_colab():
        raise OSError("request_with_credentials() must run in a Colab environment")
    from google.colab import output  # noqa
    from IPython import display  # noqa

    # Inject JS that POSTs with browser cookies ('credentials: include') and stores the promise in window._hub_tmp
    display.display(
        display.Javascript(
            """
            window._hub_tmp = new Promise((resolve, reject) => {
                const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
                fetch("%s", {
                    method: 'POST',
                    credentials: 'include'
                })
                    .then((response) => resolve(response.json()))
                    .then((json) => {
                        clearTimeout(timeout);
                    }).catch((err) => {
                        clearTimeout(timeout);
                        reject(err);
                    });
            });
            """
            % url
        )
    )
    # Block until the browser-side promise settles and return its value to Python
    return output.eval_js("_hub_tmp")
78
+
79
+
80
def requests_with_progress(method, url, **kwargs):
    """
    Make an HTTP request using the specified method and URL, with an optional progress bar.

    Args:
        method (str): The HTTP method to use (e.g. 'GET', 'POST').
        url (str): The URL to send the request to.
        **kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function.

    Returns:
        (requests.Response): The response object from the HTTP request.

    Note:
        - If 'progress' is set to True, the progress bar will display the download progress for responses with a known
          content length.
        - If 'progress' is a number then progress bar will display assuming content length = progress.
    """
    progress = kwargs.pop("progress", False)
    if not progress:
        return requests.request(method, url, **kwargs)
    response = requests.request(method, url, stream=True, **kwargs)
    total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress)  # total size
    pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
    try:
        for data in response.iter_content(chunk_size=1024):
            pbar.update(len(data))
    except requests.exceptions.ChunkedEncodingError:  # avoid 'Connection broken: IncompleteRead' warnings
        response.close()
    finally:
        pbar.close()  # always close the bar (original leaked it when the stream broke mid-download)
    return response
110
+
111
+
112
def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
    """
    Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.

    Args:
        method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
        url (str): The URL to make the request to.
        retry (int, optional): Number of retries to attempt before giving up. Default is 3.
        timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
        thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
        code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
        verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
        progress (bool, optional): Whether to show a progress bar during the request. Default is False.
        **kwargs (dict): Keyword arguments to be passed to the requests function specified in method.

    Returns:
        (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
    """
    retry_codes = (408, 500)  # retry only these codes

    @TryExcept(verbose=verbose)
    def func(func_method, func_url, **func_kwargs):
        """Make HTTP requests with retries and timeouts, with optional progress tracking."""
        r = None  # response
        t0 = time.time()  # initial time for timer
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                break  # total time budget exhausted, give up
            r = requests_with_progress(func_method, func_url, **func_kwargs)  # i.e. get(url, data, json, files)
            if r.status_code < 300:  # return codes in the 2xx range are generally considered "good" or "successful"
                break
            try:
                m = r.json().get("message", "No JSON message.")
            except AttributeError:
                m = "Unable to read JSON."
            if i == 0:  # only log a detailed message on the first failure
                if r.status_code in retry_codes:
                    m += f" Retrying {retry}x for {timeout}s." if retry else ""
                elif r.status_code == 429:  # rate limit
                    h = r.headers  # response headers
                    m = (
                        f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
                        f"Please retry after {h['Retry-After']}s."
                    )
                if verbose:
                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
                if r.status_code not in retry_codes:
                    return r  # non-retryable failure, return immediately
            time.sleep(2**i)  # exponential standoff
        return r

    args = method, url  # positional args forwarded to func either via Thread or directly
    kwargs["progress"] = progress
    if thread:
        # Fire-and-forget: response is discarded in threaded mode
        threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
    else:
        return func(*args, **kwargs)
169
+
170
+
171
class Events:
    """
    A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
    disabled when sync=False. Run 'yolo settings' to see and update settings YAML file.

    Attributes:
        url (str): The URL to send anonymous events.
        rate_limit (float): The rate limit in seconds for sending events.
        metadata (dict): A dictionary containing metadata about the environment.
        enabled (bool): A flag to enable or disable Events based on certain conditions.
    """

    # Google Analytics 4 Measurement Protocol endpoint (measurement id + api secret in the query string)
    url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"

    def __init__(self):
        """Initializes the Events object with default values for events, rate_limit, and metadata."""
        self.events = []  # events list
        self.rate_limit = 60.0  # rate limit (seconds)
        self.t = 0.0  # rate limit timer (seconds)
        self.metadata = {
            "cli": Path(sys.argv[0]).name == "yolo",
            "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
            "python": ".".join(platform.python_version_tuple()[:2]),  # i.e. 3.10
            "version": __version__,
            "env": ENVIRONMENT,
            "session_id": round(random.random() * 1e15),
            "engagement_time_msec": 1000,
        }
        # Only report from rank 0 (or single-process), online, non-test runs of an official install
        self.enabled = (
            SETTINGS["sync"]
            and RANK in (-1, 0)
            and not TESTS_RUNNING
            and ONLINE
            and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
        )

    def __call__(self, cfg):
        """
        Attempts to add a new event to the events list and send events if the rate limit is reached.

        Args:
            cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
        """
        if not self.enabled:
            # Events disabled, do nothing
            return

        # Attempt to add to events
        if len(self.events) < 25:  # Events list limited to 25 events (drop any events past this)
            params = {
                **self.metadata,
                "task": cfg.task,
                "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
            }
            if cfg.mode == "export":
                params["format"] = cfg.format
            self.events.append({"name": cfg.mode, "params": params})

        # Check rate limit
        t = time.time()
        if (t - self.t) < self.rate_limit:
            # Time is under rate limiter, wait to send
            return

        # Time is over rate limiter, send now
        data = {"client_id": SETTINGS["uuid"], "events": self.events}  # SHA-256 anonymized UUID hash and events list

        # POST equivalent to requests.post(self.url, json=data); threaded + no retries so it never blocks training
        smart_request("post", self.url, json=data, retry=0, verbose=False)

        # Reset events and rate limit timer
        self.events = []
        self.t = t
244
+
245
+
246
# Run below code on hub/utils init -------------------------------------------------------------------------------------
events = Events()  # module-level singleton used by callbacks to report anonymous usage events
yolov8_model/ultralytics/models/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from .rtdetr import RTDETR
4
+ from .sam import SAM
5
+ from .yolo import YOLO
6
+
7
__all__ = "YOLO", "RTDETR", "SAM"  # allow simpler import; tuple of this package's public model classes
yolov8_model/ultralytics/models/fastsam/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from .model import FastSAM
4
+ from .predict import FastSAMPredictor
5
+ from .prompt import FastSAMPrompt
6
+ from .val import FastSAMValidator
7
+
8
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator"  # public API of the fastsam subpackage
yolov8_model/ultralytics/models/fastsam/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (390 Bytes). View file