Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- yolov8_model/ultralytics/data/__init__.py +15 -0
- yolov8_model/ultralytics/data/__pycache__/__init__.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/__pycache__/augment.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/__pycache__/base.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/__pycache__/build.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/__pycache__/converter.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/__pycache__/dataset.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/__pycache__/loaders.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/__pycache__/utils.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/dataset.py +375 -0
- yolov8_model/ultralytics/data/explorer/__init__.py +5 -0
- yolov8_model/ultralytics/data/explorer/__pycache__/__init__.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/explorer/__pycache__/explorer.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/explorer/__pycache__/utils.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/data/explorer/explorer.py +471 -0
- yolov8_model/ultralytics/data/explorer/gui/__init__.py +1 -0
- yolov8_model/ultralytics/data/explorer/gui/dash.py +268 -0
- yolov8_model/ultralytics/data/explorer/utils.py +166 -0
- yolov8_model/ultralytics/data/loaders.py +533 -0
- yolov8_model/ultralytics/data/scripts/download_weights.sh +18 -0
- yolov8_model/ultralytics/data/scripts/get_coco.sh +60 -0
- yolov8_model/ultralytics/data/scripts/get_coco128.sh +17 -0
- yolov8_model/ultralytics/data/scripts/get_imagenet.sh +51 -0
- yolov8_model/ultralytics/data/split_dota.py +288 -0
- yolov8_model/ultralytics/data/utils.py +647 -0
- yolov8_model/ultralytics/engine/__init__.py +1 -0
- yolov8_model/ultralytics/engine/__pycache__/__init__.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/engine/__pycache__/exporter.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/engine/__pycache__/model.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/engine/__pycache__/predictor.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/engine/__pycache__/results.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/engine/__pycache__/trainer.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/engine/__pycache__/validator.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/engine/exporter.py +1099 -0
- yolov8_model/ultralytics/engine/model.py +772 -0
- yolov8_model/ultralytics/engine/predictor.py +407 -0
- yolov8_model/ultralytics/engine/results.py +680 -0
- yolov8_model/ultralytics/engine/trainer.py +755 -0
- yolov8_model/ultralytics/engine/tuner.py +240 -0
- yolov8_model/ultralytics/engine/validator.py +336 -0
- yolov8_model/ultralytics/hub/__init__.py +128 -0
- yolov8_model/ultralytics/hub/__pycache__/__init__.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/hub/__pycache__/auth.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/hub/__pycache__/utils.cpython-310.pyc +0 -0
- yolov8_model/ultralytics/hub/auth.py +136 -0
- yolov8_model/ultralytics/hub/session.py +348 -0
- yolov8_model/ultralytics/hub/utils.py +247 -0
- yolov8_model/ultralytics/models/__init__.py +7 -0
- yolov8_model/ultralytics/models/fastsam/__init__.py +8 -0
- yolov8_model/ultralytics/models/fastsam/__pycache__/__init__.cpython-310.pyc +0 -0
yolov8_model/ultralytics/data/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Package entry point for ultralytics.data: re-exports the dataset classes and
# dataloader builders defined in the sibling modules so callers can simply do
# `from ultralytics.data import YOLODataset, build_dataloader`, etc.

from .base import BaseDataset
from .build import build_dataloader, build_yolo_dataset, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset

# Explicit public API of this package.
__all__ = (
    "BaseDataset",
    "ClassificationDataset",
    "SemanticDataset",
    "YOLODataset",
    "build_yolo_dataset",
    "build_dataloader",
    "load_inference_source",
)
|
yolov8_model/ultralytics/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (470 Bytes). View file
|
|
|
yolov8_model/ultralytics/data/__pycache__/augment.cpython-310.pyc
ADDED
|
Binary file (44.5 kB). View file
|
|
|
yolov8_model/ultralytics/data/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
yolov8_model/ultralytics/data/__pycache__/build.cpython-310.pyc
ADDED
|
Binary file (6.24 kB). View file
|
|
|
yolov8_model/ultralytics/data/__pycache__/converter.cpython-310.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
yolov8_model/ultralytics/data/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
yolov8_model/ultralytics/data/__pycache__/loaders.cpython-310.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
yolov8_model/ultralytics/data/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
yolov8_model/ultralytics/data/dataset.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
import contextlib
|
| 3 |
+
from itertools import repeat
|
| 4 |
+
from multiprocessing.pool import ThreadPool
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torchvision
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
from yolov8_model.ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
|
| 14 |
+
from yolov8_model.ultralytics.utils.ops import resample_segments
|
| 15 |
+
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
|
| 16 |
+
from .base import BaseDataset
|
| 17 |
+
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
|
| 18 |
+
|
| 19 |
+
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
|
| 20 |
+
DATASET_CACHE_VERSION = "1.0.3"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes the YOLODataset with optional configurations for segments and keypoints."""
        # Task flags drive which label fields are verified, cached and formatted.
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path("./labels.cache")):
        """
        Cache dataset labels, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file. Default is Path('./labels.cache').

        Returns:
            (dict): labels.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        # Verify every image/label pair in parallel; verify_image_label yields
        # per-file stats that are accumulated into the cache dict below.
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format="xywh",
                        )
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        # Hash of the file lists lets get_labels() detect a stale cache later.
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x)
        return x

    def get_labels(self):
        """Returns dictionary of labels for YOLO training."""
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in (-1, 0):
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache: drop bookkeeping keys, keeping only the label dicts.
        for k in ("hash", "version", "msgs"):  # plain loop: popping for side effect, not building a list
            cache.pop(k)
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp=None):
        """Builds and appends transforms to the list."""
        if self.augment:
            # Mosaic/mixup are disabled for rectangular training batches.
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
            )
        )
        return transforms

    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Custom your label format here.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # list[np.array(1000, 2)] * num_samples
            # (N, 1000, 2)
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches."""
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        _CAT_KEYS = {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}  # set: O(1) membership per key
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)
            elif k in _CAT_KEYS:  # disjoint from "img", so elif is equivalent
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLO Classification Dataset.

    Args:
        root (str): Dataset path.

    Attributes:
        cache_ram (bool): True if images should be cached in RAM, False otherwise.
        cache_disk (bool): True if images should be cached on disk, False otherwise.
        samples (list): List of samples containing file, index, npy, and im.
        torch_transforms (callable): torchvision transforms applied to the dataset.
        album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
    """

    def __init__(self, root, args, augment=False, cache=False, prefix=""):
        """
        Initialize YOLO object with root, image size, augmentations, and cache settings.

        Args:
            root (str): Dataset path.
            args (Namespace): Argument parser containing dataset related settings.
            augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
            cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
        """
        super().__init__(root=root)
        if augment and args.fraction < 1.0:  # reduce training fraction
            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
        self.cache_ram = cache is True or cache == "ram"
        self.cache_disk = cache == "disk"
        self.samples = self.verify_images()  # filter out bad images
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
        self.torch_transforms = (
            classify_augmentations(
                size=args.imgsz,
                scale=scale,
                hflip=args.fliplr,
                vflip=args.flipud,
                erasing=args.erasing,
                auto_augment=args.auto_augment,
                hsv_h=args.hsv_h,
                hsv_s=args.hsv_s,
                hsv_v=args.hsv_v,
            )
            if augment
            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
        )

    def __getitem__(self, i):
        """Returns subset of data and targets corresponding to given indices."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            # BUGFIX: previously `if self.cache_ram and im is None` fell through to the
            # final `else` when the image WAS already cached, re-reading it from disk on
            # every access. Nest the None-check so a cached `im` is actually reused.
            if im is None:
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self):
        """Verify all images in dataset."""
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")  # *.cache file path

        # Fast path: reuse a previously built *.cache file when version and hash match.
        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
            nf, nc, n, samples = cache.pop("results")  # found, corrupt, total, verified samples
            if LOCAL_RANK in (-1, 0):
                d = f"{desc} {nf} images, {nc} corrupt"
                TQDM(None, desc=d, total=n, initial=n)
                if cache["msgs"]:
                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
            return samples

        # Run scan if *.cache retrieval failed
        nf, nc, msgs, samples, x = 0, 0, [], [], {}
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
            pbar = TQDM(results, desc=desc, total=len(self.samples))
            for sample, nf_f, nc_f, msg in pbar:
                if nf_f:
                    samples.append(sample)
                if msg:
                    msgs.append(msg)
                nf += nf_f
                nc += nc_f
                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        x["hash"] = get_hash([x[0] for x in self.samples])
        x["results"] = nf, nc, len(samples), samples
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x)
        return samples
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def load_dataset_cache_file(path):
    """
    Load an Ultralytics *.cache dictionary from path.

    Args:
        path (str | Path): Location of the cache file (a NumPy pickle of a dict).

    Returns:
        (dict): The cached dataset dictionary.
    """
    import gc

    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    try:
        cache = np.load(str(path), allow_pickle=True).item()  # load dict
    finally:
        # BUGFIX: always re-enable the garbage collector; previously a failed
        # np.load left GC disabled for the rest of the process.
        gc.enable()
    return cache
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def save_dataset_cache_file(prefix, path, x):
    """Save an Ultralytics dataset *.cache dictionary x to path."""
    x["version"] = DATASET_CACHE_VERSION  # stamp the cache with the current version
    if not is_dir_writeable(path.parent):
        # Nothing we can do without write access; warn and bail out.
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
        return
    if path.exists():
        path.unlink()  # remove *.cache file if exists
    np.save(str(path), x)  # save cache for next time (numpy appends a .npy suffix)
    path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
    LOGGER.info(f"{prefix}New cache created: {path}")
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """
    Semantic Segmentation Dataset.

    Handles datasets intended for semantic segmentation tasks, inheriting its
    behavior from :class:`BaseDataset`.

    Note:
        Currently a stub: methods and attributes supporting semantic
        segmentation still need to be implemented.
    """

    def __init__(self):
        """Initialize a SemanticDataset object."""
        super().__init__()
|
yolov8_model/ultralytics/data/explorer/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Package entry point for ultralytics.data.explorer: only the plotting helper
# is re-exported as public API.

from .utils import plot_query_result

__all__ = ["plot_query_result"]
|
yolov8_model/ultralytics/data/explorer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (248 Bytes). View file
|
|
|
yolov8_model/ultralytics/data/explorer/__pycache__/explorer.cpython-310.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
yolov8_model/ultralytics/data/explorer/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (7.39 kB). View file
|
|
|
yolov8_model/ultralytics/data/explorer/explorer.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, List, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from matplotlib import pyplot as plt
|
| 12 |
+
from pandas import DataFrame
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from yolov8_model.ultralytics.data.augment import Format
|
| 16 |
+
from yolov8_model.ultralytics.data.dataset import YOLODataset
|
| 17 |
+
from yolov8_model.ultralytics.data.utils import check_det_dataset
|
| 18 |
+
from yolov8_model.ultralytics.models.yolo.model import YOLO
|
| 19 |
+
from yolov8_model.ultralytics.utils import LOGGER, IterableSimpleNamespace, checks, USER_CONFIG_DIR
|
| 20 |
+
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ExplorerDataset(YOLODataset):
    """YOLODataset variant for the Explorer: loads images at native size (no resize transforms)."""

    def __init__(self, *args, data: dict = None, **kwargs) -> None:
        """Initialize ExplorerDataset, forwarding all arguments to YOLODataset."""
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
        """Loads 1 image from dataset index 'i' without any resize ops."""
        # ims / im_files / npy_files are populated by the BaseDataset ancestor —
        # presumably RAM cache, image paths and .npy cache paths; confirm against BaseDataset.
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                if im is None:
                    raise FileNotFoundError(f"Image Not Found {f}")
            h0, w0 = im.shape[:2]  # orig hw
            # Unlike YOLODataset.load_image, no resize: resized hw equals original hw.
            return im, (h0, w0), im.shape[:2]

        # RAM-cached path: return cached image with its stored original/current shapes.
        return self.ims[i], self.im_hw0[i], self.im_hw[i]

    def build_transforms(self, hyp: IterableSimpleNamespace = None):
        """Creates transforms for dataset images without resizing."""
        # Only a Format step (no LetterBox/mosaic): boxes stay un-normalized xyxy.
        return Format(
            bbox_format="xyxy",
            normalize=False,
            return_mask=self.use_segments,
            return_keypoint=self.use_keypoints,
            batch_idx=True,
            mask_ratio=hyp.mask_ratio,
            mask_overlap=hyp.overlap_mask,
        )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Explorer:
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
data: Union[str, Path] = "coco128.yaml",
|
| 59 |
+
model: str = "yolov8n.pt",
|
| 60 |
+
uri: str = USER_CONFIG_DIR / "explorer",
|
| 61 |
+
) -> None:
|
| 62 |
+
checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
|
| 63 |
+
import lancedb
|
| 64 |
+
|
| 65 |
+
self.connection = lancedb.connect(uri)
|
| 66 |
+
self.table_name = Path(data).name.lower() + "_" + model.lower()
|
| 67 |
+
self.sim_idx_base_name = (
|
| 68 |
+
f"{self.table_name}_sim_idx".lower()
|
| 69 |
+
) # Use this name and append thres and top_k to reuse the table
|
| 70 |
+
self.model = YOLO(model)
|
| 71 |
+
self.data = data # None
|
| 72 |
+
self.choice_set = None
|
| 73 |
+
|
| 74 |
+
self.table = None
|
| 75 |
+
self.progress = 0
|
| 76 |
+
|
| 77 |
+
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
    """
    Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
    already exists. Pass force=True to overwrite the existing table.

    Args:
        force (bool): Whether to overwrite the existing table or not. Defaults to False.
        split (str): Split of the dataset to use. Defaults to 'train'.

    Raises:
        ValueError: If no dataset was supplied or the requested split is missing.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        ```
    """
    # Fast path 1: table already built during this session.
    if self.table is not None and not force:
        LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
        return
    # Fast path 2: table persisted on disk from a previous run.
    if self.table_name in self.connection.table_names() and not force:
        LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
        self.table = self.connection.open_table(self.table_name)
        self.progress = 1
        return
    if self.data is None:
        raise ValueError("Data must be provided to create embeddings table")

    data_info = check_det_dataset(self.data)
    if split not in data_info:
        raise ValueError(
            f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
        )

    # Normalize the split's image source(s) to a list of paths.
    choice_set = data_info[split]
    choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
    self.choice_set = choice_set
    dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)

    # Create the table schema: probe one sample to learn the embedding dimension.
    batch = dataset[0]
    vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
    table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
    # Stream sanitized samples (with their embeddings) into the table lazily.
    table.add(
        self._yield_batches(
            dataset,
            data_info,
            self.model,
            exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
        )
    )

    self.table = table
|
| 128 |
+
|
| 129 |
+
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
    """Yield one-element lists of sanitized samples, each augmented with its embedding vector."""
    total = len(dataset)
    for idx in tqdm(range(total)):
        self.progress = float(idx + 1) / total
        sample = dataset[idx]
        # Drop keys LanceDB does not need to store.
        for key in exclude_keys:
            sample.pop(key, None)
        sample = sanitize_batch(sample, data_info)
        sample["vector"] = model.embed(sample["im_file"], verbose=False)[0].detach().tolist()
        yield [sample]
|
| 139 |
+
|
| 140 |
+
def query(
    self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
) -> Any:  # pyarrow.Table
    """
    Query the table for similar images. Accepts a single image or a list of images.

    Args:
        imgs (str | np.ndarray | list): Path to an image, an image array, or a list of either.
        limit (int): Number of results to return.

    Returns:
        (pyarrow.Table): An arrow table containing the results. Supports converting to:
            - pandas dataframe: `result.to_pandas()`
            - dict of lists: `result.to_pydict()`

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        similar = exp.query(img='https://ultralytics.com/images/zidane.jpg')
        ```
    """
    if self.table is None:
        raise ValueError("Table is not created. Please create the table first.")
    # Fixed: the signature allows a bare np.ndarray, but previously only str was wrapped,
    # so a single array tripped the list assertion below.
    if isinstance(imgs, (str, np.ndarray)):
        imgs = [imgs]
    assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
    embeds = self.model.embed(imgs)
    # Get avg if multiple images are passed (len > 1)
    embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
    return self.table.search(embeds).limit(limit).to_arrow()
|
| 171 |
+
|
| 172 |
+
def sql_query(
    self, query: str, return_type: str = "pandas"
) -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table
    """
    Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.

    Args:
        query (str): SQL query to run. Must start with 'SELECT' or 'WHERE' (uppercase).
        return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

    Returns:
        (pandas.DataFrame | pyarrow.Table): The query results in the requested format.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
        result = exp.sql_query(query)
        ```
    """
    assert return_type in {
        "pandas",
        "arrow",
    }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
    import duckdb

    if self.table is None:
        raise ValueError("Table is not created. Please create the table first.")

    # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
    # DuckDB resolves the identifier 'table' in the SQL text to this local variable by frame inspection.
    table = self.table.to_arrow()  # noqa NOTE: Don't comment this. This line is used by DuckDB
    if not query.startswith("SELECT") and not query.startswith("WHERE"):
        raise ValueError(
            f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
        )
    if query.startswith("WHERE"):
        # Allow passing just a WHERE clause; expand it into a full SELECT.
        query = f"SELECT * FROM 'table' {query}"
    LOGGER.info(f"Running query: {query}")

    rs = duckdb.sql(query)
    if return_type == "arrow":
        return rs.arrow()
    elif return_type == "pandas":
        return rs.df()
|
| 217 |
+
|
| 218 |
+
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
    """
    Render the rows matched by a SQL-like query as a single mosaic image.

    Args:
        query (str): SQL query to run.
        labels (bool): Whether to plot the labels or not.

    Returns:
        (PIL.Image): Image containing the plot, or None when the query matches nothing.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
        result = exp.plot_sql_query(query)
        ```
    """
    rows = self.sql_query(query, return_type="arrow")
    if not len(rows):
        LOGGER.info("No results found.")
        return None
    return Image.fromarray(plot_query_result(rows, plot_labels=labels))
|
| 242 |
+
|
| 243 |
+
def get_similar(
    self,
    img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
    idx: Union[int, List[int]] = None,
    limit: int = 25,
    return_type: str = "pandas",
) -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table
    """
    Search the table for images similar to the given image(s) or table index(es).

    Args:
        img (str or list): Path to the image or a list of paths to the images.
        idx (int or list): Index of the image in the table or a list of indexes.
        limit (int): Number of results to return. Defaults to 25.
        return_type (str): Either 'pandas' or 'arrow'. Defaults to 'pandas'.

    Returns:
        (pandas.DataFrame | pyarrow.Table): The matching rows.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
        ```
    """
    assert return_type in {
        "pandas",
        "arrow",
    }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
    images = self._check_imgs_or_idxs(img, idx)
    result = self.query(images, limit=limit)
    # 'arrow' hands back the raw table; 'pandas' converts it.
    return result if return_type == "arrow" else result.to_pandas()
|
| 280 |
+
|
| 281 |
+
def plot_similar(
    self,
    img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
    idx: Union[int, List[int]] = None,
    limit: int = 25,
    labels: bool = True,
) -> Image.Image:
    """
    Plot a mosaic of images similar to the given image(s) or table index(es).

    Args:
        img (str or list): Path to the image or a list of paths to the images.
        idx (int or list): Index of the image in the table or a list of indexes.
        limit (int): Number of results to return. Defaults to 25.
        labels (bool): Whether to plot the labels or not.

    Returns:
        (PIL.Image): Image containing the plot, or None when nothing matches.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
        ```
    """
    matches = self.get_similar(img, idx, limit, return_type="arrow")
    if not len(matches):
        LOGGER.info("No results found.")
        return None
    return Image.fromarray(plot_query_result(matches, plot_labels=labels))
|
| 313 |
+
|
| 314 |
+
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
    """
    Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
    are max_dist or closer to the image in the embedding space at a given index.

    Args:
        max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
        top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
                       vector search. Defaults: None.
        force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

    Returns:
        (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
            include indices of similar images and their respective distances.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        sim_idx = exp.similarity_index()
        ```
    """
    if self.table is None:
        raise ValueError("Table is not created. Please create the table first.")
    # Parameters are baked into the table name so each (max_dist, top_k) pair is cached separately.
    sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
    if sim_idx_table_name in self.connection.table_names() and not force:
        LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
        return self.connection.open_table(sim_idx_table_name).to_pandas()

    if top_k and not (1.0 >= top_k >= 0.0):
        raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
    if max_dist < 0.0:
        raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")

    # Convert the top_k fraction into an absolute row count (at least 1).
    top_k = int(top_k * len(self.table)) if top_k else len(self.table)
    top_k = max(top_k, 1)
    features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
    im_files = features["im_file"]
    embeddings = features["vector"]

    sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

    def _yield_sim_idx():
        """Yield one similarity-index row per image: its neighbors within max_dist (limited to top_k)."""
        for i in tqdm(range(len(embeddings))):
            sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
            yield [
                {
                    "idx": i,
                    "im_file": im_files[i],
                    "count": len(sim_idx),
                    "sim_im_files": sim_idx["im_file"].tolist(),
                }
            ]

    sim_table.add(_yield_sim_idx())
    self.sim_index = sim_table
    return sim_table.to_pandas()
|
| 372 |
+
|
| 373 |
+
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
    """
    Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
    max_dist or closer to the image in the embedding space at a given index.

    Args:
        max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
        top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
            running vector search. Defaults to None.
        force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

    Returns:
        (PIL.Image): Image containing the plot.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()

        similarity_idx_plot = exp.plot_similarity_index()
        similarity_idx_plot.show() # view image preview
        similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
        ```
    """
    sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
    sim_count = np.array(sim_idx["count"].tolist())
    indices = np.arange(len(sim_count))

    # Draw on a dedicated figure so repeated calls don't pile bars onto the shared
    # global pyplot state, and close it afterwards to release the memory.
    fig, ax = plt.subplots()
    ax.bar(indices, sim_count)
    ax.set_xlabel("data idx")
    ax.set_ylabel("Count")
    ax.set_title("Similarity Count")

    buffer = BytesIO()
    fig.savefig(buffer, format="png")
    plt.close(fig)
    buffer.seek(0)

    # Decode into an in-memory array so the returned image holds no reference to the buffer.
    return Image.fromarray(np.array(Image.open(buffer)))
|
| 416 |
+
|
| 417 |
+
def _check_imgs_or_idxs(
    self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
) -> List[np.ndarray]:
    """Validate that exactly one of img/idx was given and return the image source(s) as a list."""
    if img is None and idx is None:
        raise ValueError("Either img or idx must be provided.")
    if img is not None and idx is not None:
        raise ValueError("Only one of img or idx must be provided.")
    if idx is not None:
        # Resolve table indexes into their stored image paths.
        indexes = idx if isinstance(idx, list) else [idx]
        img = self.table.to_lance().take(indexes, columns=["im_file"]).to_pydict()["im_file"]
    return img if isinstance(img, list) else [img]
|
| 429 |
+
|
| 430 |
+
def ask_ai(self, query):
    """
    Ask AI a question.

    Args:
        query (str): Question to ask.

    Returns:
        (pandas.DataFrame): A dataframe containing filtered results to the SQL query, or None on failure.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        answer = exp.ask_ai('Show images with 1 person and 2 dogs')
        ```
    """
    generated = prompt_sql_query(query)
    try:
        return self.sql_query(generated)
    except Exception as e:
        # The LLM may emit invalid SQL; log the failure and bail out gracefully.
        LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
        LOGGER.error(e)
        return None
|
| 455 |
+
|
| 456 |
+
def visualize(self, result):
    """
    Visualize the results of a query. TODO: not implemented yet; currently a no-op placeholder.

    Args:
        result (pyarrow.Table): Table containing the results of a query.
    """
    pass
|
| 464 |
+
|
| 465 |
+
def generate_report(self, result):
    """
    Generate a report of the dataset.

    TODO: not implemented yet; currently a no-op placeholder.

    Args:
        result: Query results to report on (currently unused).
    """
    pass
|
yolov8_model/ultralytics/data/explorer/gui/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
yolov8_model/ultralytics/data/explorer/gui/dash.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from threading import Thread
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from ultralytics import Explorer
|
| 9 |
+
from ultralytics.utils import ROOT, SETTINGS
|
| 10 |
+
from ultralytics.utils.checks import check_requirements
|
| 11 |
+
|
| 12 |
+
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3"))
|
| 13 |
+
|
| 14 |
+
import streamlit as st
|
| 15 |
+
from streamlit_select import image_select
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _get_explorer():
    """Create an Explorer from the form selections, build its embeddings table with a progress bar, and store it in session state."""
    exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
    # Build embeddings on a worker thread so the Streamlit progress bar can keep updating.
    thread = Thread(
        target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
    )
    thread.start()
    progress_bar = st.progress(0, text="Creating embeddings table...")
    # Poll exp.progress (updated by the worker) until the table is complete.
    while exp.progress < 1:
        time.sleep(0.1)
        progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
    thread.join()
    st.session_state["explorer"] = exp
    progress_bar.empty()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def init_explorer_form():
    """Render the dataset/model selection form that bootstraps the Explorer."""
    dataset_dir = ROOT / "cfg" / "datasets"
    ds = [d.name for d in dataset_dir.glob("*.yaml")]
    # Detect, segment and pose checkpoints, each in n/s/m/l/x sizes.
    models = [f"yolov8{size}{task}.pt" for task in ("", "-seg", "-pose") for size in "nsmlx"]
    with st.form(key="explorer_init_form"):
        col1, col2 = st.columns(2)
        with col1:
            st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
        with col2:
            st.selectbox("Select model", models, key="model")
        st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")

        st.form_submit_button("Explore", on_click=_get_explorer)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def query_form():
    """Render the manual SQL query input form."""
    with st.form("query_form"):
        text_col, button_col = st.columns([0.8, 0.2])
        with text_col:
            st.text_input(
                "Query",
                "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
                label_visibility="collapsed",
                key="query",
            )
        with button_col:
            st.form_submit_button("Query", on_click=run_sql_query)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def ai_query_form():
    """Render the natural-language (AI) query input form."""
    with st.form("ai_query_form"):
        text_col, button_col = st.columns([0.8, 0.2])
        with text_col:
            st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
        with button_col:
            st.form_submit_button("Ask AI", on_click=run_ai_query)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def find_similar_imgs(imgs):
    """Search the embeddings table for images similar to the selection and stash paths/results in session state."""
    exp = st.session_state["explorer"]
    similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
    paths = similar.to_pydict()["im_file"]
    st.session_state["imgs"] = paths
    st.session_state["res"] = similar
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def similarity_form(selected_imgs):
    """Render the similarity-search form (result limit + search button) for the currently selected images."""
    st.write("Similarity Search")
    with st.form("similarity_form"):
        subcol1, subcol2 = st.columns([1, 1])
        with subcol1:
            st.number_input(
                "limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
            )

        with subcol2:
            # Disable the button until at least one image is selected in the grid.
            disabled = not len(selected_imgs)
            st.write("Selected: ", len(selected_imgs))
            st.form_submit_button(
                "Search",
                disabled=disabled,
                on_click=find_similar_imgs,
                args=(selected_imgs,),
            )
        if disabled:
            st.error("Select at least one image to search.")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# def persist_reset_form():
|
| 124 |
+
# with st.form("persist_reset"):
|
| 125 |
+
# col1, col2 = st.columns([1, 1])
|
| 126 |
+
# with col1:
|
| 127 |
+
# st.form_submit_button("Reset", on_click=reset)
|
| 128 |
+
#
|
| 129 |
+
# with col2:
|
| 130 |
+
# st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True))
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def run_sql_query():
    """Execute the SQL query from session state and store the matching rows and image paths in session state."""
    st.session_state["error"] = None
    query = st.session_state.get("query")
    if query.strip():  # idiomatic replacement of rstrip().lstrip(); skip empty/whitespace-only queries
        exp = st.session_state["explorer"]
        res = exp.sql_query(query, return_type="arrow")
        st.session_state["imgs"] = res.to_pydict()["im_file"]
        st.session_state["res"] = res
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def run_ai_query():
    """Run the natural-language query through the AI SQL generator and store results in session state."""
    if not SETTINGS["openai_api_key"]:
        st.session_state[
            "error"
        ] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
        return
    st.session_state["error"] = None
    query = st.session_state.get("ai_query")
    if query.strip():  # idiomatic replacement of rstrip().lstrip(); skip empty/whitespace-only queries
        exp = st.session_state["explorer"]
        res = exp.ask_ai(query)
        # ask_ai returns None on failure; also guard against an empty dataframe.
        if not isinstance(res, pd.DataFrame) or res.empty:
            st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
            return
        st.session_state["imgs"] = res["im_file"].to_list()
        st.session_state["res"] = res
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def reset_explorer():
    """Clear explorer-related session state so the dataset selection form is shown again."""
    for key in ("explorer", "imgs", "error"):
        st.session_state[key] = None
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def utralytics_explorer_docs_callback():
    """Render the Ultralytics Explorer documentation card (logo, blurb, and API docs link)."""
    # NOTE: the function name contains a typo ('utralytics'); kept unchanged for caller compatibility.
    with st.container(border=True):
        st.image(
            "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
            width=100,
        )
        st.markdown(
            "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
            unsafe_allow_html=True,
            help=None,
        )
        st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def layout():
    """Top-level Streamlit page: dataset selection, image grid, query forms and similarity sidebar."""
    st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
    st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)

    # No explorer yet -> show only the dataset/model selection form.
    if st.session_state.get("explorer") is None:
        init_explorer_form()
        return

    st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
    exp = st.session_state.get("explorer")
    col1, col2 = st.columns([0.75, 0.25], gap="small")
    imgs = []
    if st.session_state.get("error"):
        st.error(st.session_state["error"])
    else:
        if st.session_state.get("imgs"):
            imgs = st.session_state.get("imgs")
        else:
            # No query has run yet: show every image stored in the table.
            imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
            st.session_state["res"] = exp.table.to_arrow()
    total_imgs, selected_imgs = len(imgs), []
    with col1:
        subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
        with subcol1:
            st.write("Max Images Displayed:")
        with subcol2:
            num = st.number_input(
                "Max Images Displayed",
                min_value=0,
                max_value=total_imgs,
                value=min(500, total_imgs),
                key="num_imgs_displayed",
                label_visibility="collapsed",
            )
        with subcol3:
            st.write("Start Index:")
        with subcol4:
            start_idx = st.number_input(
                "Start Index",
                min_value=0,
                max_value=total_imgs,
                value=0,
                key="start_index",
                label_visibility="collapsed",
            )
        with subcol5:
            reset = st.button("Reset", use_container_width=False, key="reset")
            if reset:
                st.session_state["imgs"] = None
                st.experimental_rerun()

        query_form()
        ai_query_form()
        if total_imgs:
            labels, boxes, masks, kpts, classes = None, None, None, None, None
            task = exp.model.task
            if st.session_state.get("display_labels"):
                # Slice the annotation columns to the currently displayed window [start_idx, start_idx + num).
                labels = st.session_state.get("res").to_pydict()["labels"][start_idx : start_idx + num]
                boxes = st.session_state.get("res").to_pydict()["bboxes"][start_idx : start_idx + num]
                masks = st.session_state.get("res").to_pydict()["masks"][start_idx : start_idx + num]
                kpts = st.session_state.get("res").to_pydict()["keypoints"][start_idx : start_idx + num]
                classes = st.session_state.get("res").to_pydict()["cls"][start_idx : start_idx + num]
            imgs_displayed = imgs[start_idx : start_idx + num]
            selected_imgs = image_select(
                f"Total samples: {total_imgs}",
                images=imgs_displayed,
                use_container_width=False,
                # indices=[i for i in range(num)] if select_all else None,
                labels=labels,
                classes=classes,
                bboxes=boxes,
                # Masks/keypoints only make sense for segment/pose models respectively.
                masks=masks if task == "segment" else None,
                kpts=kpts if task == "pose" else None,
            )

    with col2:
        similarity_form(selected_imgs)
        display_labels = st.checkbox("Labels", value=False, key="display_labels")
        utralytics_explorer_docs_callback()
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Run the Streamlit app when executed directly (e.g. `streamlit run dash.py`).
if __name__ == "__main__":
    layout()
|
yolov8_model/ultralytics/data/explorer/utils.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import getpass
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from yolov8_model.ultralytics.data.augment import LetterBox
|
| 11 |
+
from yolov8_model.ultralytics.utils import LOGGER as logger
|
| 12 |
+
from yolov8_model.ultralytics.utils import SETTINGS
|
| 13 |
+
from yolov8_model.ultralytics.utils.checks import check_requirements
|
| 14 |
+
from yolov8_model.ultralytics.utils.ops import xyxy2xywh
|
| 15 |
+
from yolov8_model.ultralytics.utils.plotting import plot_images
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_table_schema(vector_size):
    """
    Return a LanceDB pydantic model class describing one row of the Explorer image table.

    Args:
        vector_size (int): Dimensionality of the embedding vector stored per image.

    Returns:
        (type): A `LanceModel` subclass usable as a LanceDB table schema.
    """
    # Imported lazily so lancedb is only required when the Explorer table is actually built.
    from lancedb.pydantic import LanceModel, Vector

    class Schema(LanceModel):
        im_file: str  # path of the source image
        labels: List[str]  # class names, one per annotation (aligned with `cls`)
        cls: List[int]  # integer class ids, aligned with `labels`
        bboxes: List[List[float]]  # one 4-float box per annotation (xyxy per sanitize_batch/plot flow)
        masks: List[List[List[int]]]  # segmentation masks; sentinel [[[]]] when absent
        keypoints: List[List[List[float]]]  # pose keypoints; sentinel [[[]]] when absent
        vector: Vector(vector_size)  # fixed-size embedding of the image

    return Schema
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_sim_index_schema():
    """
    Return a LanceDB pydantic model class for the similarity-index table.

    Returns:
        (type): A `LanceModel` subclass with one row per indexed image.
    """
    # Imported lazily so lancedb is only required when the similarity index is built.
    from lancedb.pydantic import LanceModel

    class Schema(LanceModel):
        idx: int  # row index of the image in the main table
        im_file: str  # path of the source image
        count: int  # number of similar images found for this image
        sim_im_files: List[str]  # paths of the similar images

    return Schema
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def sanitize_batch(batch, dataset_info):
    """
    Convert a dataset batch into plain Python lists, ordered deterministically by class id.

    Args:
        batch (dict): Batch with tensor values for 'cls' and 'bboxes' (optionally 'masks'/'keypoints').
        dataset_info (dict): Dataset metadata; `dataset_info["names"]` maps class id -> class name.

    Returns:
        (dict): The same dict with list values and an added 'labels' key.
    """
    # Flatten the class tensor into a plain list of ints.
    classes = batch["cls"].flatten().int().tolist()
    # Pair each box with its class and sort by class id so output ordering is stable.
    pairs = sorted(zip(batch["bboxes"].tolist(), classes), key=lambda pair: pair[1])
    batch["cls"] = [c for _, c in pairs]
    batch["bboxes"] = [b for b, _ in pairs]
    batch["labels"] = [dataset_info["names"][c] for c in batch["cls"]]
    # Masks and keypoints are optional; use the Explorer "empty" sentinel shape when missing.
    batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
    batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
    return batch
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def plot_query_result(similar_set, plot_labels=True):
    """
    Plot images from the similar set returned by an Explorer query.

    Args:
        similar_set (pyarrow.Table | pandas.DataFrame): Similar data points.
        plot_labels (bool): Whether to plot boxes/masks/keypoints labels.

    Returns:
        The image grid produced by `plot_images` (save=False, threaded=False).
    """
    similar_set = (
        similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
    )
    empty_masks = [[[]]]
    empty_boxes = [[]]
    images = similar_set.get("im_file", [])
    # BUGFIX: the original used `is not empty_boxes`, an identity comparison against a freshly
    # created list literal, which is always True (the emptiness check never fired). Compare by value.
    bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes", []) != empty_boxes else []
    # BUGFIX: `.get(...)[0]` raised TypeError/IndexError when the column was missing or empty;
    # fall back to the sentinel so the check is safe.
    masks = similar_set.get("masks") if (similar_set.get("masks") or [empty_masks])[0] != empty_masks else []
    kpts = similar_set.get("keypoints") if (similar_set.get("keypoints") or [empty_masks])[0] != empty_masks else []
    cls = similar_set.get("cls", [])

    plot_size = 640
    imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
    for i, imf in enumerate(images):
        im = cv2.imread(imf)
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        h, w = im.shape[:2]
        r = min(plot_size / h, plot_size / w)  # LetterBox resize ratio, used to rescale labels
        imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1))
        if plot_labels:
            if len(bboxes) > i and len(bboxes[i]) > 0:
                box = np.array(bboxes[i], dtype=np.float32)
                box[:, [0, 2]] *= r
                box[:, [1, 3]] *= r
                plot_boxes.append(box)
            if len(masks) > i and len(masks[i]) > 0:
                mask = np.array(masks[i], dtype=np.uint8)[0]
                plot_masks.append(LetterBox(plot_size, center=False)(image=mask))
            if len(kpts) > i and kpts[i] is not None:
                kpt = np.array(kpts[i], dtype=np.float32)
                kpt[:, :, :2] *= r
                plot_kpts.append(kpt)
        # One batch index per box of this image so plot_images can group annotations.
        batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i)
    imgs = np.stack(imgs, axis=0)
    masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
    kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
    batch_idx = np.concatenate(batch_idx, axis=0)
    cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)

    return plot_images(
        imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
    )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def prompt_sql_query(query):
    """
    Ask an OpenAI chat model to translate a natural-language request into one SQL query
    over the Explorer table schema, returning the raw model response text.

    Args:
        query (str): Natural-language request from the user.

    Returns:
        (str): The model's reply, expected to be a single `SELECT * FROM 'table' ...` query.
    """
    check_requirements("openai>=1.6.1")
    from openai import OpenAI

    # Prompt interactively for a key the first time and persist it in SETTINGS.
    if not SETTINGS["openai_api_key"]:
        logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
        openai_api_key = getpass.getpass("OpenAI API key: ")
        SETTINGS.update({"openai_api_key": openai_api_key})
    openai = OpenAI(api_key=SETTINGS["openai_api_key"])

    messages = [
        {
            "role": "system",
            "content": """
                You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
                the following schema and a user request. You only need to output the format with fixed selection
                statement that selects everything from "'table'", like `SELECT * from 'table'`

                Schema:
                im_file: string not null
                labels: list<item: string> not null
                child 0, item: string
                cls: list<item: int64> not null
                child 0, item: int64
                bboxes: list<item: list<item: double>> not null
                child 0, item: list<item: double>
                child 0, item: double
                masks: list<item: list<item: list<item: int64>>> not null
                child 0, item: list<item: list<item: int64>>
                child 0, item: list<item: int64>
                child 0, item: int64
                keypoints: list<item: list<item: list<item: double>>> not null
                child 0, item: list<item: list<item: double>>
                child 0, item: list<item: double>
                child 0, item: double
                vector: fixed_size_list<item: float>[256] not null
                child 0, item: float

                Some details about the schema:
                - the "labels" column contains the string values like 'person' and 'dog' for the respective objects
                in each image
                - the "cls" column contains the integer values on these classes that map them the labels

                Example of a correct query:
                request - Get all data points that contain 2 or more people and at least one dog
                correct query-
                SELECT * FROM 'table' WHERE  ARRAY_LENGTH(cls) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2  AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
             """,
        },
        {"role": "user", "content": f"{query}"},
    ]

    response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
    return response.choices[0].message.content
|
yolov8_model/ultralytics/data/loaders.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import glob
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from threading import Thread
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
import requests
|
| 15 |
+
import torch
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
from yolov8_model.ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
| 19 |
+
from yolov8_model.ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
|
| 20 |
+
from yolov8_model.ultralytics.utils.checks import check_requirements
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class SourceTypes:
    """Flags describing which kind of input source a prediction request came from."""

    webcam: bool = False  # live stream source (webcam index, RTSP/RTMP/HTTP URL, *.streams file)
    screenshot: bool = False  # screen-capture source ('screen')
    from_img: bool = False  # in-memory PIL Image / numpy array input
    tensor: bool = False  # torch.Tensor input
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LoadStreams:
    """
    Stream Loader for various types of video streams.

    Suitable for use with `yolo predict source='rtsp://example.com/media.mp4'`, supports RTSP, RTMP, HTTP, and TCP
    streams.

    Attributes:
        sources (list): Cleaned source names/URLs for the video streams.
        vid_stride (int): Video frame-rate stride, defaults to 1.
        buffer (bool): Whether to buffer input streams, defaults to False.
        running (bool): Flag to indicate if the streaming thread is running.
        mode (str): Set to 'stream' indicating real-time capture.
        imgs (list): Per-stream buffers of captured frames.
        fps (list): FPS for each stream (30 fallback when unknown).
        frames (list): Total frames for each stream (inf for live streams).
        threads (list): Daemon reader thread for each stream.
        shape (list): First-frame shape for each stream.
        caps (list): cv2.VideoCapture object for each stream.
        bs (int): Batch size (number of streams).

    Methods:
        __init__: Initialize the stream loader.
        update: Read stream frames in daemon thread.
        close: Close stream loader and release resources.
        __iter__: Returns an iterator object for the class.
        __next__: Returns source paths, transformed, and original images for processing.
        __len__: Return the length of the sources object.
    """

    def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
        """Initialize instance variables and check for consistent input stream shapes."""
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.buffer = buffer  # buffer input streams
        self.running = True  # running flag for Thread
        self.mode = "stream"
        self.vid_stride = vid_stride  # video frame-rate stride

        # A *.streams text file holds one source per line; otherwise treat `sources` as a single source.
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.fps = [0] * n  # frames per second
        self.frames = [0] * n
        self.threads = [None] * n
        self.caps = [None] * n  # video capture objects
        self.imgs = [[] for _ in range(n)]  # images
        self.shape = [[] for _ in range(n)]  # image shapes
        self.sources = [ops.clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f"{i + 1}/{n}: {s}... "
            if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
                s = get_best_youtube_url(s)
            # BUGFIX/security: replaced eval(s) with int(s); the string is guarded by isnumeric()
            # so int() is equivalent, and eval() on an externally supplied source is unsafe.
            s = int(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0 and (is_colab() or is_kaggle()):
                raise NotImplementedError(
                    "'source=0' webcam not supported in Colab and Kaggle notebooks. "
                    "Try running 'source=0' in a local environment."
                )
            self.caps[i] = cv2.VideoCapture(s)  # store video capture object
            if not self.caps[i].isOpened():
                raise ConnectionError(f"{st}Failed to open {s}")
            w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = self.caps[i].get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
                "inf"
            )  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            success, im = self.caps[i].read()  # guarantee first frame
            if not success or im is None:
                raise ConnectionError(f"{st}Failed to read images from {s}")
            self.imgs[i].append(im)
            self.shape[i] = im.shape
            # Thread args as a tuple (was a parenthesized list; equivalent but conventional).
            self.threads[i] = Thread(target=self.update, args=(i, self.caps[i], s), daemon=True)
            # BUGFIX: the success marker was mojibake ("Success ��"); restore the intended checkmark.
            LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info("")  # newline

        # Check for common shapes
        self.bs = self.__len__()

    def update(self, i, cap, stream):
        """Read stream `i` frames in daemon thread."""
        n, f = 0, self.frames[i]  # frame number, frame array
        while self.running and cap.isOpened() and n < (f - 1):
            if len(self.imgs[i]) < 30:  # keep a <=30-image buffer
                n += 1
                cap.grab()  # .read() = .grab() followed by .retrieve()
                if n % self.vid_stride == 0:
                    success, im = cap.retrieve()
                    if not success:
                        im = np.zeros(self.shape[i], dtype=np.uint8)
                        LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.")
                        cap.open(stream)  # re-open stream if signal was lost
                    if self.buffer:
                        self.imgs[i].append(im)
                    else:
                        self.imgs[i] = [im]
            else:
                time.sleep(0.01)  # wait until the buffer is empty

    def close(self):
        """Close stream loader and release resources."""
        self.running = False  # stop flag for Thread
        for thread in self.threads:
            if thread.is_alive():
                thread.join(timeout=5)  # Add timeout
        for cap in self.caps:  # Iterate through the stored VideoCapture objects
            try:
                cap.release()  # release video capture
            except Exception as e:
                LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}")
        cv2.destroyAllWindows()

    def __iter__(self):
        """Iterates through YOLO image feed and re-opens unresponsive streams."""
        self.count = -1
        return self

    def __next__(self):
        """Returns source paths, transformed and original images for processing."""
        self.count += 1

        images = []
        for i, x in enumerate(self.imgs):
            # Wait until a frame is available in each buffer
            while not x:
                if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"):  # q to quit
                    self.close()
                    raise StopIteration
                time.sleep(1 / min(self.fps))
                x = self.imgs[i]
                if not x:
                    LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}")

            # Get and remove the first frame from imgs buffer
            if self.buffer:
                images.append(x.pop(0))

            # Get the last frame, and clear the rest from the imgs buffer
            else:
                images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
                x.clear()

        return self.sources, images, None, ""

    def __len__(self):
        """Return the length of the sources object."""
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class LoadScreenshots:
    """
    YOLOv8 screenshot dataloader.

    This class manages the loading of screenshot images for processing with YOLOv8.
    Suitable for use with `yolo predict source=screen`.

    Attributes:
        source (str): The source input indicating which screen to capture.
        screen (int): The screen number to capture.
        left (int): The left coordinate for screen capture area.
        top (int): The top coordinate for screen capture area.
        width (int): The width of the screen capture area.
        height (int): The height of the screen capture area.
        mode (str): Set to 'stream' indicating real-time capture.
        frame (int): Counter for captured frames.
        sct (mss.mss): Screen capture object from `mss` library.
        bs (int): Batch size, set to 1.
        monitor (dict): Monitor configuration details.

    Methods:
        __iter__: Returns an iterator object.
        __next__: Captures the next screenshot and returns it.
    """

    def __init__(self, source):
        """Source = [screen_number left top width height] (pixels)."""
        check_requirements("mss")
        import mss  # noqa

        # `source` may be "screen", "screen 0", "screen left top w h" or "screen n left top w h".
        source, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
        if len(params) == 1:
            self.screen = int(params[0])
        elif len(params) == 4:
            left, top, width, height = (int(x) for x in params)
        elif len(params) == 5:
            self.screen, left, top, width, height = (int(x) for x in params)
        self.mode = "stream"
        self.frame = 0
        self.sct = mss.mss()
        self.bs = 1

        # Parse monitor shape: offsets are relative to the chosen monitor's origin,
        # missing width/height fall back to the full monitor size.
        monitor = self.sct.monitors[self.screen]
        self.top = monitor["top"] if top is None else (monitor["top"] + top)
        self.left = monitor["left"] if left is None else (monitor["left"] + left)
        self.width = width or monitor["width"]
        self.height = height or monitor["height"]
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        """Returns an iterator of the object."""
        return self

    def __next__(self):
        """mss screen capture: get raw pixels from the screen as np array."""
        im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3]  # BGRA to BGR
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        self.frame += 1
        return [str(self.screen)], [im0], None, s  # screen, img, vid_cap, string
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class LoadImages:
    """
    YOLOv8 image/video dataloader.

    This class manages the loading and pre-processing of image and video data for YOLOv8. It supports loading from
    various formats, including single image files, video files, and lists of image and video paths.

    Attributes:
        files (list): List of image and video file paths (images first, then videos).
        nf (int): Total number of files (images and videos).
        video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
        mode (str): Current mode, 'image' or 'video'.
        vid_stride (int): Stride for video frame-rate, defaults to 1.
        bs (int): Batch size, set to 1 for this class.
        cap (cv2.VideoCapture): Video capture object for OpenCV.
        frame (int): Frame counter for video.
        frames (int): Total number of frames in the video.
        count (int): Counter for iteration, initialized at 0 during `__iter__()`.

    Methods:
        _new_video(path): Create a new cv2.VideoCapture object for a given video path.
    """

    def __init__(self, path, vid_stride=1):
        """Initialize the Dataloader and raise FileNotFoundError if file not found."""
        parent = None
        if isinstance(path, str) and Path(path).suffix == ".txt":  # *.txt file with img/vid/dir on each line
            parent = Path(path).parent
            path = Path(path).read_text().splitlines()  # list of sources
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            a = str(Path(p).absolute())  # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
            if "*" in a:
                files.extend(sorted(glob.glob(a, recursive=True)))  # glob
            elif os.path.isdir(a):
                files.extend(sorted(glob.glob(os.path.join(a, "*.*"))))  # dir
            elif os.path.isfile(a):
                files.append(a)  # files (absolute or relative to CWD)
            elif parent and (parent / p).is_file():
                files.append(str((parent / p).absolute()))  # files (relative to *.txt file parent)
            else:
                raise FileNotFoundError(f"{p} does not exist")

        # Split by extension; anything not matching a known image/video format is dropped here.
        images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = "image"
        self.vid_stride = vid_stride  # video frame-rate stride
        self.bs = 1
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        if self.nf == 0:
            # NOTE(review): `p` here is the last path examined in the loop above — TODO confirm
            # the message is intended to name that path rather than the original `path` argument.
            raise FileNotFoundError(
                f"No images or videos found in {p}. "
                f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
            )

    def __iter__(self):
        """Returns an iterator object for VideoStream or ImageFolder."""
        self.count = 0
        return self

    def __next__(self):
        """Return next image, path and metadata from dataset."""
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = "video"
            # Skip vid_stride-1 frames cheaply with grab() before decoding with retrieve().
            for _ in range(self.vid_stride):
                self.cap.grab()
            success, im0 = self.cap.retrieve()
            while not success:
                # Current video exhausted: advance to the next file, re-raising StopIteration at the end.
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                success, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f"video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: "

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            if im0 is None:
                raise FileNotFoundError(f"Image Not Found {path}")
            s = f"image {self.count}/{self.nf} {path}: "

        return [path], [im0], self.cap, s

    def _new_video(self, path):
        """Create a new video capture object."""
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

    def __len__(self):
        """Returns the number of files in the object."""
        return self.nf  # number of files
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class LoadPilAndNumpy:
    """
    Batch loader for in-memory PIL and numpy images.

    Validates each input, converts PIL images to contiguous BGR numpy arrays, and exposes the
    whole batch through a single-step iterator.

    Attributes:
        paths (list): Image paths taken from PIL `filename` when available, otherwise autogenerated.
        im0 (list): Images stored as numpy arrays.
        mode (str): Type of data being processed, 'image'.
        bs (int): Batch size, equal to `len(im0)`.
        count (int): Iteration counter, reset to 0 by `__iter__()`.

    Methods:
        _single_check(im): Validate and format a single image to a numpy array.
    """

    def __init__(self, im0):
        """Initialize from a single image or a list of PIL/numpy images."""
        images = im0 if isinstance(im0, list) else [im0]
        # Prefer the PIL-provided filename; synthesize one per index otherwise.
        self.paths = [getattr(image, "filename", f"image{idx}.jpg") for idx, image in enumerate(images)]
        self.im0 = [self._single_check(image) for image in images]
        self.mode = "image"
        self.bs = len(self.im0)

    @staticmethod
    def _single_check(im):
        """Validate one image and return it as a contiguous numpy array (PIL inputs become BGR)."""
        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
        if not isinstance(im, Image.Image):
            return im  # already a numpy array, pass through unchanged
        if im.mode != "RGB":
            im = im.convert("RGB")
        flipped = np.asarray(im)[:, :, ::-1]  # RGB -> BGR channel order
        return np.ascontiguousarray(flipped)  # contiguous

    def __len__(self):
        """Returns the length of the 'im0' attribute."""
        return len(self.im0)

    def __next__(self):
        """Returns batch paths, images, processed images, None, ''."""
        if self.count == 1:  # loop only once as it's batch inference
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, None, ""

    def __iter__(self):
        """Enables iteration for class LoadPilAndNumpy."""
        self.count = 0
        return self
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class LoadTensor:
|
| 421 |
+
"""
|
| 422 |
+
Load images from torch.Tensor data.
|
| 423 |
+
|
| 424 |
+
This class manages the loading and pre-processing of image data from PyTorch tensors for further processing.
|
| 425 |
+
|
| 426 |
+
Attributes:
|
| 427 |
+
im0 (torch.Tensor): The input tensor containing the image(s).
|
| 428 |
+
bs (int): Batch size, inferred from the shape of `im0`.
|
| 429 |
+
mode (str): Current mode, set to 'image'.
|
| 430 |
+
paths (list): List of image paths or filenames.
|
| 431 |
+
count (int): Counter for iteration, initialized at 0 during `__iter__()`.
|
| 432 |
+
|
| 433 |
+
Methods:
|
| 434 |
+
_single_check(im, stride): Validate and possibly modify the input tensor.
|
| 435 |
+
"""
|
| 436 |
+
|
| 437 |
+
def __init__(self, im0) -> None:
|
| 438 |
+
"""Initialize Tensor Dataloader."""
|
| 439 |
+
self.im0 = self._single_check(im0)
|
| 440 |
+
self.bs = self.im0.shape[0]
|
| 441 |
+
self.mode = "image"
|
| 442 |
+
self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
|
| 443 |
+
|
| 444 |
+
@staticmethod
|
| 445 |
+
def _single_check(im, stride=32):
|
| 446 |
+
"""Validate and format an image to torch.Tensor."""
|
| 447 |
+
s = (
|
| 448 |
+
f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
|
| 449 |
+
f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
|
| 450 |
+
)
|
| 451 |
+
if len(im.shape) != 4:
|
| 452 |
+
if len(im.shape) != 3:
|
| 453 |
+
raise ValueError(s)
|
| 454 |
+
LOGGER.warning(s)
|
| 455 |
+
im = im.unsqueeze(0)
|
| 456 |
+
if im.shape[2] % stride or im.shape[3] % stride:
|
| 457 |
+
raise ValueError(s)
|
| 458 |
+
if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07
|
| 459 |
+
LOGGER.warning(
|
| 460 |
+
f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. "
|
| 461 |
+
f"Dividing input by 255."
|
| 462 |
+
)
|
| 463 |
+
im = im.float() / 255.0
|
| 464 |
+
|
| 465 |
+
return im
|
| 466 |
+
|
| 467 |
+
def __iter__(self):
|
| 468 |
+
"""Returns an iterator object."""
|
| 469 |
+
self.count = 0
|
| 470 |
+
return self
|
| 471 |
+
|
| 472 |
+
def __next__(self):
|
| 473 |
+
"""Return next item in the iterator."""
|
| 474 |
+
if self.count == 1:
|
| 475 |
+
raise StopIteration
|
| 476 |
+
self.count += 1
|
| 477 |
+
return self.paths, self.im0, None, ""
|
| 478 |
+
|
| 479 |
+
def __len__(self):
    """Return the number of images in the pre-loaded batch."""
    return self.bs
|
| 483 |
+
|
| 484 |
+
def autocast_list(source):
    """Convert a heterogeneous list of prediction sources into PIL images or numpy arrays."""
    files = []
    for im in source:
        if isinstance(im, (str, Path)):  # filename or uri
            stream = requests.get(im, stream=True).raw if str(im).startswith("http") else im
            files.append(Image.open(stream))
        elif isinstance(im, (Image.Image, np.ndarray)):  # already-decoded PIL or numpy image
            files.append(im)
        else:
            raise TypeError(
                f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
                f"See https://docs.ultralytics.com/modes/predict for supported source types."
            )

    return files
+
|
| 500 |
+
|
| 501 |
+
LOADERS = LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots  # tuple of the supported loader classes
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def get_best_youtube_url(url, use_pafy=True):
    """
    Retrieve the direct URL of the best-quality MP4 video stream for a YouTube video.

    Uses either the pafy or the yt_dlp library to extract video info, then picks the highest-quality
    MP4 format that has a video codec but no audio codec.

    Args:
        url (str): The URL of the YouTube video.
        use_pafy (bool): Use the pafy package, default=True, otherwise use yt_dlp package.

    Returns:
        (str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
    """
    if use_pafy:
        check_requirements(("pafy", "youtube_dl==2020.12.2"))
        import pafy  # noqa

        return pafy.new(url).getbestvideo(preftype="mp4").url

    check_requirements("yt-dlp")
    import yt_dlp

    with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
        info_dict = ydl.extract_info(url, download=False)  # extract info
        for fmt in reversed(info_dict.get("formats", [])):  # best formats are typically listed last
            # Accept only video-codec-only *.mp4 formats at least 1920x1080 in one dimension
            wide_enough = (fmt.get("width") or 0) >= 1920
            tall_enough = (fmt.get("height") or 0) >= 1080
            if (wide_enough or tall_enough) and fmt["vcodec"] != "none" and fmt["acodec"] == "none" and fmt["ext"] == "mp4":
                return fmt.get("url")
|
yolov8_model/ultralytics/data/scripts/download_weights.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download latest models from https://github.com/ultralytics/assets/releases
# Example usage: bash ultralytics/data/scripts/download_weights.sh
# parent
# └── weights
#     ├── yolov8n.pt  ← downloads here
#     ├── yolov8s.pt
#     └── ...

# Run an inline Python snippet that downloads every YOLOv8 variant:
# sizes n/s/m/l/x crossed with detect ('' suffix), -cls, -seg, and -pose heads.
python - <<EOF
from ultralytics.utils.downloads import attempt_download_asset

assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')]
for x in assets:
    attempt_download_asset(f'weights/{x}')

EOF
|
yolov8_model/ultralytics/data/scripts/get_coco.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download COCO 2017 dataset https://cocodataset.org
# Example usage: bash data/scripts/get_coco.sh
# parent
# ├── ultralytics
# └── datasets
#     └── coco  ← downloads here

# Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments
if [ "$#" -gt 0 ]; then
  for opt in "$@"; do
    case "${opt}" in
      --train) train=true ;;
      --val) val=true ;;
      --test) test=true ;;
      --segments) segments=true ;;
      --sama) sama=true ;;
    esac
  done
else
  # Defaults when no flags are given: train + val images, box-only labels
  train=true
  val=true
  test=false
  segments=false
  sama=false
fi

# Download/unzip labels (runs in the background; see 'wait' at the end)
d='../datasets' # unzip directory
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
if [ "$segments" == "true" ]; then
  f='coco2017labels-segments.zip' # 169 MB
elif [ "$sama" == "true" ]; then
  f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/
else
  f='coco2017labels.zip' # 46 MB
fi
echo 'Downloading' $url$f ' ...'
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &

# Download/unzip images (each archive downloads and extracts in a parallel background job)
d='../datasets/coco/images' # unzip directory
url=http://images.cocodataset.org/zips/
if [ "$train" == "true" ]; then
  f='train2017.zip' # 19G, 118k images
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
if [ "$val" == "true" ]; then
  f='val2017.zip' # 1G, 5k images
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
if [ "$test" == "true" ]; then
  f='test2017.zip' # 7G, 41k images (optional)
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
wait # finish background tasks
|
yolov8_model/ultralytics/data/scripts/get_coco128.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017)
# Example usage: bash data/scripts/get_coco128.sh
# parent
# ├── ultralytics
# └── datasets
#     └── coco128  ← downloads here

# Download/unzip images and labels (single archive, fetched as a background job)
d='../datasets' # unzip directory
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
f='coco128.zip' # or 'coco128-segments.zip', 68 MB
echo 'Downloading' $url$f ' ...'
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &

wait # finish background tasks
|
yolov8_model/ultralytics/data/scripts/get_imagenet.sh
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download ILSVRC2012 ImageNet dataset https://image-net.org
# Example usage: bash data/scripts/get_imagenet.sh
# parent
# ├── ultralytics
# └── datasets
#     └── imagenet  ← downloads here

# Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val
if [ "$#" -gt 0 ]; then
  for opt in "$@"; do
    case "${opt}" in
      --train) train=true ;;
      --val) val=true ;;
    esac
  done
else
  # Default: fetch both splits
  train=true
  val=true
fi

# Make dir
d='../datasets/imagenet' # unzip directory
mkdir -p $d && cd $d

# Download/unzip train: the outer tar contains one inner tar per class,
# each of which is extracted into its own subdirectory and then deleted.
if [ "$train" == "true" ]; then
  wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images
  mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
  tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
  find . -name "*.tar" | while read NAME; do
    mkdir -p "${NAME%.tar}"
    tar -xf "${NAME}" -C "${NAME%.tar}"
    rm -f "${NAME}"
  done
  cd ..
fi

# Download/unzip val; valprep.sh sorts the flat val images into per-class subdirs
if [ "$val" == "true" ]; then
  wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images
  mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar
  wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs
fi

# Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail)
# rm train/n04266014/n04266014_10835.JPEG

# TFRecords (optional)
# wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt
|
yolov8_model/ultralytics/data/split_dota.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
from glob import glob
|
| 5 |
+
from math import ceil
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from yolov8_model.ultralytics.data.utils import exif_size, img2label_paths
|
| 14 |
+
from yolov8_model.ultralytics.utils.checks import check_requirements
|
| 15 |
+
|
| 16 |
+
check_requirements("shapely")
|
| 17 |
+
from shapely.geometry import Polygon
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def bbox_iof(polygon1, bbox2, eps=1e-6):
    """
    Calculate intersection-over-foreground (IoF) between polygons and horizontal bounding boxes.

    Args:
        polygon1 (np.ndarray): Polygon coordinates with shape (n, 8) as x1,y1,...,x4,y4.
        bbox2 (np.ndarray): Horizontal bounding boxes with shape (m, 4) as x1,y1,x2,y2.
        eps (float): Small value clamped into the denominator to avoid division by zero.

    Returns:
        (np.ndarray): IoF matrix of shape (n, m); each entry is the polygon/box intersection
            area divided by the polygon's own area.
    """
    polygon1 = polygon1.reshape(-1, 4, 2)
    lt_point = np.min(polygon1, axis=-2)  # left-top corner of each polygon's AABB
    rb_point = np.max(polygon1, axis=-2)  # right-bottom corner of each polygon's AABB
    bbox1 = np.concatenate([lt_point, rb_point], axis=-1)

    # Cheap AABB-overlap pre-check: exact polygon intersection is computed only where this is > 0
    lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2])
    rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:])
    wh = np.clip(rb - lt, 0, np.inf)
    h_overlaps = wh[..., 0] * wh[..., 1]

    # Renamed from l/t/r/b: single-letter 'l' is ambiguous (PEP 8 / E741)
    left, top, right, bottom = (bbox2[..., i] for i in range(4))
    polygon2 = np.stack([left, top, right, top, right, bottom, left, bottom], axis=-1).reshape(-1, 4, 2)

    sg_polys1 = [Polygon(p) for p in polygon1]
    sg_polys2 = [Polygon(p) for p in polygon2]
    overlaps = np.zeros(h_overlaps.shape)
    for p in zip(*np.nonzero(h_overlaps)):  # only candidate pairs with positive AABB overlap
        overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area
    unions = np.array([p.area for p in sg_polys1], dtype=np.float32)
    unions = unions[..., None]

    unions = np.clip(unions, eps, np.inf)  # guard against zero-area polygons
    outputs = overlaps / unions
    if outputs.ndim == 1:
        outputs = outputs[..., None]
    return outputs
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_yolo_dota(data_root, split="train"):
    """
    Load DOTA annotations for a dataset split.

    Args:
        data_root (str): Data root directory.
        split (str): The split data set, could be train or val.

    Returns:
        (list[dict]): One dict per image with keys `ori_size` (h, w), `label` (np.ndarray), `filepath`.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    assert split in ["train", "val"]
    im_dir = Path(data_root) / "images" / split
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(Path(data_root) / "images" / split / "*"))
    lb_files = img2label_paths(im_files)
    annos = []
    for im_file, lb_file in zip(im_files, lb_files):
        w, h = exif_size(Image.open(im_file))  # EXIF-corrected (width, height)
        with open(lb_file) as f:
            rows = [line.split() for line in f.read().strip().splitlines() if len(line)]
        annos.append(dict(ori_size=(h, w), label=np.array(rows, dtype=np.float32), filepath=im_file))
    return annos
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_windows(im_size, crop_sizes=(1024,), gaps=(200,), im_rate_thr=0.6, eps=0.01):
    """
    Compute sliding-window crop coordinates covering an image.

    Defaults are now immutable tuples instead of mutable list default arguments
    (a classic Python pitfall); passing lists still works.

    Args:
        im_size (tuple): Original image size, (h, w).
        crop_sizes (tuple | list): Crop sizes of windows.
        gaps (tuple | list): Gap (overlap) between consecutive crops, paired with crop_sizes.
        im_rate_thr (float): Threshold of window area inside the image divided by window area.
        eps (float): Tolerance used to keep the best-covered window(s) when none pass the threshold.

    Returns:
        (np.ndarray): Window coordinates of shape (k, 4) as x_start, y_start, x_stop, y_stop.
    """
    h, w = im_size
    windows = []
    for crop_size, gap in zip(crop_sizes, gaps):
        assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
        step = crop_size - gap

        xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
        xs = [step * i for i in range(xn)]
        if len(xs) > 1 and xs[-1] + crop_size > w:
            xs[-1] = w - crop_size  # clamp the last column of windows inside the image

        yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1)
        ys = [step * i for i in range(yn)]
        if len(ys) > 1 and ys[-1] + crop_size > h:
            ys[-1] = h - crop_size  # clamp the last row of windows inside the image

        start = np.array(list(itertools.product(xs, ys)), dtype=np.int64)
        stop = start + crop_size
        windows.append(np.concatenate([start, stop], axis=1))
    windows = np.concatenate(windows, axis=0)

    # Fraction of each window lying inside the image; windows mostly outside are dropped
    im_in_wins = windows.copy()
    im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w)
    im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h)
    im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1])
    win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])
    im_rates = im_areas / win_areas
    if not (im_rates > im_rate_thr).any():
        # No window passes: keep the best-covered window(s) so the image is never dropped entirely
        max_rate = im_rates.max()
        im_rates[abs(im_rates - max_rate) < eps] = 1
    return windows[im_rates > im_rate_thr]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_window_obj(anno, windows, iof_thr=0.7):
    """Assign labels to windows by IoF; NOTE: denormalizes `anno["label"]` coordinates in place."""
    h, w = anno["ori_size"]
    label = anno["label"]
    if not len(label):
        return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]  # window_anns
    label[:, 1::2] *= w  # x coords: normalized -> pixels (in place)
    label[:, 2::2] *= h  # y coords: normalized -> pixels (in place)
    iofs = bbox_iof(label[:, 1:], windows)
    # Unnormalized and misaligned coordinates
    return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]  # window_anns
|
| 145 |
+
|
| 146 |
+
def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
    """
    Crop images and save new labels.

    Writes one .jpg patch per window into `im_dir` and, for windows that contain
    objects, one YOLO-format .txt label file into `lb_dir`.

    Args:
        anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
        windows (list): A list of windows coordinates.
        window_objs (list): A list of labels inside each window.
        im_dir (str): The output directory path of images.
        lb_dir (str): The output directory path of labels.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    im = cv2.imread(anno["filepath"])
    name = Path(anno["filepath"]).stem
    for i, window in enumerate(windows):
        x_start, y_start, x_stop, y_stop = window.tolist()
        # Patch name encodes the crop size and the window's top-left origin
        new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
        patch_im = im[y_start:y_stop, x_start:x_stop]
        ph, pw = patch_im.shape[:2]

        cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im)
        label = window_objs[i]
        if len(label) == 0:
            continue  # empty window: image patch saved, no label file written
        # Shift pixel coords into the window frame, then normalize by patch size (mutates label in place)
        label[:, 1::2] -= x_start
        label[:, 2::2] -= y_start
        label[:, 1::2] /= pw
        label[:, 2::2] /= ph

        with open(Path(lb_dir) / f"{new_name}.txt", "w") as f:
            for lb in label:
                formatted_coords = ["{:.6g}".format(coord) for coord in lb[1:]]
                f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
|
| 189 |
+
|
| 190 |
+
def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024,), gaps=(200,)):
    """
    Split both images and labels of one split into cropped windows.

    Defaults are now immutable tuples instead of mutable list default arguments;
    passing lists still works.

    Args:
        data_root (str): Root directory of the source dataset.
        save_dir (str): Output root directory.
        split (str): Split to process, "train" or "val".
        crop_sizes (tuple | list): Window sizes to crop.
        gaps (tuple | list): Gaps between consecutive windows.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - split
                - labels
                    - split
        and the output directory structure is:
            - save_dir
                - images
                    - split
                - labels
                    - split
    """
    im_dir = Path(save_dir) / "images" / split
    im_dir.mkdir(parents=True, exist_ok=True)
    lb_dir = Path(save_dir) / "labels" / split
    lb_dir.mkdir(parents=True, exist_ok=True)

    annos = load_yolo_dota(data_root, split=split)
    for anno in tqdm(annos, total=len(annos), desc=split):
        windows = get_windows(anno["ori_size"], crop_sizes, gaps)
        window_objs = get_window_obj(anno, windows)  # also denormalizes anno["label"] in place
        crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))
|
| 219 |
+
|
| 220 |
+
def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
    """
    Split train and val set of DOTA.

    The `rates` default is now an immutable tuple instead of a mutable list default
    argument; passing a list still works. Each rate r yields an extra scale with
    window size crop_size/r and gap gap/r.

    Args:
        data_root (str): Root directory of the source dataset.
        save_dir (str): Output root directory.
        crop_size (int): Base window size.
        gap (int): Base gap between consecutive windows.
        rates (tuple | list): Scale rates applied to crop_size and gap.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
        and the output directory structure is:
            - save_dir
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    crop_sizes, gaps = [], []
    for r in rates:
        crop_sizes.append(int(crop_size / r))
        gaps.append(int(gap / r))
    for split in ["train", "val"]:
        split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
+
|
| 249 |
+
|
| 250 |
+
def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
    """
    Split test set of DOTA, labels are not included within this set.

    The `rates` default is now an immutable tuple instead of a mutable list default
    argument; passing a list still works.

    Args:
        data_root (str): Root directory of the source dataset.
        save_dir (str): Output root directory.
        crop_size (int): Base window size.
        gap (int): Base gap between consecutive windows.
        rates (tuple | list): Scale rates applied to crop_size and gap.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - test
        and the output directory structure is:
            - save_dir
                - images
                    - test
    """
    crop_sizes, gaps = [], []
    for r in rates:
        crop_sizes.append(int(crop_size / r))
        gaps.append(int(gap / r))
    save_dir = Path(save_dir) / "images" / "test"
    save_dir.mkdir(parents=True, exist_ok=True)

    im_dir = Path(data_root) / "images" / "test"
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(im_dir / "*"))
    for im_file in tqdm(im_files, total=len(im_files), desc="test"):
        w, h = exif_size(Image.open(im_file))  # EXIF-corrected (width, height)
        windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
        im = cv2.imread(im_file)
        name = Path(im_file).stem
        for window in windows:
            x_start, y_start, x_stop, y_stop = window.tolist()
            # Patch name encodes the crop size and the window's top-left origin
            new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
            patch_im = im[y_start:y_stop, x_start:x_stop]
            cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im)
+
|
| 285 |
+
|
| 286 |
+
if __name__ == "__main__":
    # Example CLI usage: split DOTAv2 train/val and test sets into windowed crops
    split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split")
    split_test(data_root="DOTAv2", save_dir="DOTAv2-split")
|
yolov8_model/ultralytics/data/utils.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import hashlib
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import subprocess
|
| 9 |
+
import time
|
| 10 |
+
import zipfile
|
| 11 |
+
from multiprocessing.pool import ThreadPool
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from tarfile import is_tarfile
|
| 14 |
+
|
| 15 |
+
import cv2
|
| 16 |
+
import numpy as np
|
| 17 |
+
from PIL import Image, ImageOps
|
| 18 |
+
|
| 19 |
+
from yolov8_model.ultralytics.nn.autobackend import check_class_names
|
| 20 |
+
from yolov8_model.ultralytics.utils import (
|
| 21 |
+
DATASETS_DIR,
|
| 22 |
+
LOGGER,
|
| 23 |
+
NUM_THREADS,
|
| 24 |
+
ROOT,
|
| 25 |
+
SETTINGS_YAML,
|
| 26 |
+
TQDM,
|
| 27 |
+
clean_url,
|
| 28 |
+
colorstr,
|
| 29 |
+
emojis,
|
| 30 |
+
yaml_load,
|
| 31 |
+
yaml_save,
|
| 32 |
+
)
|
| 33 |
+
from yolov8_model.ultralytics.utils.checks import check_file, check_font, is_ascii
|
| 34 |
+
from yolov8_model.ultralytics.utils.downloads import download, safe_download, unzip_file
|
| 35 |
+
from yolov8_model.ultralytics.utils.ops import segments2boxes
|
| 36 |
+
|
| 37 |
+
# Help message appended to dataset-format errors
HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"  # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders (env-var override)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def img2label_paths(img_paths):
    """Derive label .txt paths from image paths by swapping the /images/ component for /labels/."""
    img_sub, lbl_sub = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
    label_paths = []
    for path in img_paths:
        swapped = lbl_sub.join(path.rsplit(img_sub, 1))  # replace only the last /images/ occurrence
        label_paths.append(swapped.rsplit(".", 1)[0] + ".txt")  # swap the extension for .txt
    return label_paths
+
|
| 48 |
+
|
| 49 |
+
def get_hash(paths):
    """Return one SHA-256 hex digest derived from the combined size and names of a list of paths."""
    total_size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # missing paths contribute 0
    digest = hashlib.sha256(str(total_size).encode())  # seed with the combined size
    digest.update("".join(paths).encode())  # mix in the path strings themselves
    return digest.hexdigest()
+
|
| 56 |
+
|
| 57 |
+
def exif_size(img: Image.Image):
    """Return the (width, height) of a PIL image, corrected for EXIF orientation (JPEG only)."""
    size = img.size  # (width, height)
    if img.format == "JPEG":  # only JPEG images carry the orientation handling here
        with contextlib.suppress(Exception):  # EXIF parsing is best-effort
            exif = img.getexif()
            if exif:
                rotation = exif.get(274, None)  # 274 is the EXIF orientation tag
                if rotation in [6, 8]:  # orientations 6/8 = 270/90 degree rotation: swap w and h
                    size = size[1], size[0]
    return size
+
|
| 69 |
+
|
| 70 |
+
def verify_image(args):
    """Verify a single (image, class) pair; returns ((im_file, cls), n_found, n_corrupt, message)."""
    (im_file, cls), prefix = args
    # Number (found, corrupt), message
    nf, nc, msg = 0, 0, ""
    try:
        img = Image.open(im_file)
        img.verify()  # PIL verify
        shape = exif_size(img)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert img.format.lower() in IMG_FORMATS, f"invalid image format {img.format}"
        if img.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)  # inspect the final two bytes
                if f.read() != b"\xff\xd9":  # corrupt JPEG: missing EOI marker
                    # Re-save via PIL to restore a structurally valid JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
        nf = 1
    except Exception as e:
        # Deliberate best-effort: a bad image is counted and reported, never raised
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), nf, nc, msg
+
|
| 94 |
+
|
| 95 |
+
def verify_image_label(args):
    """
    Verify one image-label pair.

    Args:
        args (tuple): (im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim) where `keypoint`
            is True for pose datasets, `num_cls` the dataset class count, and `nkpt`/`ndim` the
            keypoint count and dimensionality (2 or 3).

    Returns:
        (tuple | list): im_file, label array (n, 5), image shape (h, w), segments, keypoints,
            followed by counters nm/nf/ne/nc (missing, found, empty, corrupt) and a warning message.
            On failure, the first five entries are None.
    """
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
        if im.format.lower() in ("jpg", "jpeg"):
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG (missing 0xFFD9 end-of-image marker)
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                if keypoint:
                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    points = lb[:, 1:]
                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"

                # All labels
                max_cls = lb[:, 0].max()  # max label count
                # FIX: valid class indices are 0..num_cls-1, so max_cls == num_cls is invalid (was '<=')
                assert max_cls < num_cls, (
                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}"
                )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            # FIX: was 'if keypoints' — 'keypoints' is still None here, so the empty array always got
            # 5 columns; use the 'keypoint' flag so pose datasets get the correct column count
            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                # Synthesize a visibility channel: keypoints with negative coords are marked invisible
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
    """
    Convert a list of polygons to a binary mask of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
        downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.

    Returns:
        (np.ndarray): A binary mask of the specified image size with the polygons filled in.
    """
    canvas = np.zeros(imgsz, dtype=np.uint8)
    pts = np.asarray(polygons, dtype=np.int32)
    pts = pts.reshape((pts.shape[0], -1, 2))  # flatten each polygon into (x, y) pairs
    cv2.fillPoly(canvas, pts, color=color)
    out_h, out_w = imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio
    # Fill at full resolution first, then resize: keeps the loss calculation identical to mask-ratio=1
    return cv2.resize(canvas, (out_w, out_h))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
    """
    Convert a list of polygons to a set of binary masks of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int): The color value to fill in the polygons on the masks.
        downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.

    Returns:
        (np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
    """
    # Rasterize each polygon independently, then stack into one array (one mask per instance).
    per_polygon = [polygon2mask(imgsz, [poly.reshape(-1)], color, downsample_ratio) for poly in polygons]
    return np.array(per_polygon)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
    """Return a (640, 640) overlap mask."""
    # uint8 suffices for <=255 instances; larger counts need int32 so instance IDs don't wrap.
    overlap = np.zeros(
        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
        dtype=np.int32 if len(segments) > 255 else np.uint8,
    )
    instance_masks = []
    instance_areas = []
    for seg in segments:
        m = polygon2mask(imgsz, [seg.reshape(-1)], downsample_ratio=downsample_ratio, color=1)
        instance_masks.append(m)
        instance_areas.append(m.sum())
    instance_areas = np.asarray(instance_areas)
    index = np.argsort(-instance_areas)  # largest instances first so small ones paint on top
    ordered = np.array(instance_masks)[index]
    for i in range(len(segments)):
        layer = ordered[i] * (i + 1)  # encode instance identity as pixel value i+1
        overlap = overlap + layer
        overlap = np.clip(overlap, a_min=0, a_max=i + 1)  # later (smaller) instances overwrite earlier ones
    return overlap, index
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def find_dataset_yaml(path: Path) -> Path:
    """
    Find and return the YAML file associated with a Detect, Segment or Pose dataset.

    This function searches for a YAML file at the root level of the provided directory first, and if not found, it
    performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError
    is raised if no YAML file is found or if multiple YAML files are found.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.
    """
    # Root-level YAMLs win; only recurse when the top level has none.
    candidates = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))
    assert candidates, f"No YAML file found in '{path.resolve()}'"
    if len(candidates) > 1:
        # Disambiguate by preferring a YAML whose stem matches the directory name.
        candidates = [f for f in candidates if f.stem == path.stem]
    assert len(candidates) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(candidates)}.\n{candidates}"
    return candidates[0]
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def check_det_dataset(dataset, autodownload=True):
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.

    Returns:
        (dict): Parsed dataset information and paths.

    Raises:
        SyntaxError: If required YAML keys ('train'/'val', 'names' or 'nc') are missing or inconsistent.
        FileNotFoundError: If validation images are missing and autodownload is disabled or unavailable.
    """

    file = check_file(dataset)

    # Download (optional)
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        # After extraction the YAML lives inside the archive directory; no further download needed.
        extract_dir, autodownload = file.parent, False

    # Read YAML
    data = yaml_load(file, append_filename=True)  # dictionary

    # Checks
    for k in "train", "val":
        if k not in data:
            # 'validation' is accepted as a legacy alias for 'val'; anything else is a hard error.
            if k != "val" or "validation" not in data:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
    if "names" not in data and "nc" not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if "names" not in data:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]  # autogenerate placeholder class names
    else:
        data["nc"] = len(data["names"])  # keep 'nc' consistent with the provided names

    data["names"] = check_class_names(data["names"])

    # Resolve paths
    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
    if not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()

    # Set paths
    data["path"] = path  # download scripts
    for k in "train", "val", "test":
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith("../"):
                    # Tolerate legacy '../' prefixes by retrying relative to the dataset root itself.
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse YAML
    val, s = (data.get(x) for x in ("val", "download"))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
                raise FileNotFoundError(m)
            t = time.time()
            r = None  # success
            if s.startswith("http") and s.endswith(".zip"):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith("bash "):  # bash script
                LOGGER.info(f"Running {s} ...")
                r = os.system(s)
            else:  # python script
                # NOTE(review): executes arbitrary code from the dataset YAML's 'download' key —
                # only safe for trusted dataset YAMLs.
                exec(s, {"yaml": data})
            dt = f"({round(time.time() - t, 1)}s)"
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")
    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

    return data  # dictionary
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def check_cls_dataset(dataset, split=""):
    """
    Checks a classification dataset such as Imagenet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.

    Args:
        dataset (str | Path): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.

    Returns:
        (dict): A dictionary containing the following keys:
            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary of class names in the dataset.
    """

    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        # Dataset missing locally — attempt an automatic download before failing.
        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)
    train_set = data_dir / "train"
    # Accept either 'val' or the legacy 'validation' directory name for the validation split.
    val_set = (
        data_dir / "val"
        if (data_dir / "val").exists()
        else data_dir / "validation"
        if (data_dir / "validation").exists()
        else None
    )  # data/test or data/val
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test
    # NOTE(review): these branches only warn; they do not actually substitute the other split —
    # the caller receives None for the missing split. Confirm this is the intended contract.
    if split == "val" and not val_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
    elif split == "test" and not test_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")

    # Class count/names are derived from the subdirectories of train/ (ImageFolder layout).
    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes
    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))

    # Print to console
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f'{colorstr(f"{k}:")} {v}...'
        if v is None:
            LOGGER.info(prefix)
        else:
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
            if nf == 0:
                if k == "train":
                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                else:
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
            elif nd != nc:
                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
            else:
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class HUBDatasetStats:
    """
    A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
        autodownload (bool): Attempt to download dataset if not found locally. Default is False.

    Example:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
        ```python
        from ultralytics.data.utils import HUBDatasetStats

        stats = HUBDatasetStats('path/to/coco8.zip', task='detect')  # detect dataset
        stats = HUBDatasetStats('path/to/coco8-seg.zip', task='segment')  # segment dataset
        stats = HUBDatasetStats('path/to/coco8-pose.zip', task='pose')  # pose dataset
        stats = HUBDatasetStats('path/to/imagenet10.zip', task='classify')  # classification dataset

        stats.get_json(save=True)
        stats.process_images()
        ```
    """

    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
        """Initialize class: resolve/unzip the dataset, validate it, and prepare the '-hub' output dirs."""
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        self.task = task  # detect, segment, pose, classify
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = yaml_load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                yaml_save(yaml_path, data)
                data = check_det_dataset(yaml_path, autodownload)  # dict
                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
            except Exception as e:
                # Wrap any failure in a HUB-identifiable error so the platform can classify it.
                raise Exception("error/HUB/dataset_stats/init") from e

        # All HUB artifacts (compressed previews, stats.json) go under '<dataset>-hub/'.
        self.hub_dir = Path(f'{data["path"]}-hub')
        self.im_dir = self.hub_dir / "images"
        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path):
        """Unzip data.zip. Returns (zipped, data_dir, yaml_path); passes non-zip paths through untouched."""
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
        )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f):
        """Saves a compressed image for HUB previews."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save=False, verbose=False):
        """
        Return dataset JSON for Ultralytics HUB.

        Args:
            save (bool): Write the computed statistics to '<hub_dir>/stats.json'. Defaults to False.
            verbose (bool): Log the statistics JSON to the console. Defaults to False.

        Returns:
            (dict): The per-split statistics dictionary (also kept on `self.stats`).
        """

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task == "segment":
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                n = labels["keypoints"].shape[0]
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1)
            else:
                raise ValueError("Undefined dataset task.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine
            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder

                dataset = ImageFolder(self.data[split])

                # Per-class image counts from the ImageFolder (index, class) pairs.
                x = np.zeros(len(dataset.classes)).astype(int)
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from yolov8_model.ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                # Per-image class histograms, summed below for instance/image statistics.
                x = np.array(
                    [
                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                    ]
                )  # shape(128x80)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        """Compress images for Ultralytics HUB. Returns the directory the previews were written to."""
        from yolov8_model.ultralytics.data import YOLODataset  # ClassificationDataset

        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            # Compress in parallel; imap preserves order, the loop just drains the iterator.
            with ThreadPool(NUM_THREADS) as pool:
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
    """
    Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python
    Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
    resized.

    Args:
        f (str): The path to the input image file.
        f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
        quality (int, optional): The image compression quality as a percentage. Default is 50%.

    Example:
        ```python
        from pathlib import Path
        from ultralytics.data.utils import compress_one_image

        for f in Path('path/to/dataset').rglob('*.jpg'):
            compress_one_image(f)
        ```
    """

    try:  # attempt PIL first
        img = Image.open(f)
        scale = max_dim / max(img.height, img.width)
        if scale < 1.0:  # only shrink images, never enlarge
            img = img.resize((int(img.width * scale), int(img.height * scale)))
        img.save(f_new or f, "JPEG", quality=quality, optimize=True)
    except Exception as e:  # PIL failed — fall back to OpenCV
        LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
        img = cv2.imread(f)
        height, width = img.shape[:2]
        scale = max_dim / max(height, width)
        if scale < 1.0:  # only shrink images, never enlarge
            img = cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), img)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
    """
    Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

    Args:
        path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
        weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
        annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.

    Example:
        ```python
        from ultralytics.data.utils import autosplit

        autosplit()
        ```
    """

    path = Path(path)  # images dir
    image_files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)
    n = len(image_files)
    random.seed(0)  # fixed seed so repeated runs produce identical splits
    split_ids = random.choices([0, 1, 2], weights=weights, k=n)  # one split index per image

    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for name in txt:
        target = path.parent / name
        if target.exists():
            target.unlink()  # start from a clean slate

    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for split_id, img in TQDM(zip(split_ids, image_files), total=n):
        # When annotated_only is set, skip images that have no companion label file.
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():
            with open(path.parent / txt[split_id], "a") as f:
                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file
|
yolov8_model/ultralytics/engine/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
yolov8_model/ultralytics/engine/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (172 Bytes). View file
|
|
|
yolov8_model/ultralytics/engine/__pycache__/exporter.cpython-310.pyc
ADDED
|
Binary file (38.5 kB). View file
|
|
|
yolov8_model/ultralytics/engine/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (35.1 kB). View file
|
|
|
yolov8_model/ultralytics/engine/__pycache__/predictor.cpython-310.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
yolov8_model/ultralytics/engine/__pycache__/results.cpython-310.pyc
ADDED
|
Binary file (27.3 kB). View file
|
|
|
yolov8_model/ultralytics/engine/__pycache__/trainer.cpython-310.pyc
ADDED
|
Binary file (26.3 kB). View file
|
|
|
yolov8_model/ultralytics/engine/__pycache__/validator.cpython-310.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
yolov8_model/ultralytics/engine/exporter.py
ADDED
|
@@ -0,0 +1,1099 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
"""
|
| 3 |
+
Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
|
| 4 |
+
|
| 5 |
+
Format | `format=argument` | Model
|
| 6 |
+
--- | --- | ---
|
| 7 |
+
PyTorch | - | yolov8n.pt
|
| 8 |
+
TorchScript | `torchscript` | yolov8n.torchscript
|
| 9 |
+
ONNX | `onnx` | yolov8n.onnx
|
| 10 |
+
OpenVINO | `openvino` | yolov8n_openvino_model/
|
| 11 |
+
TensorRT | `engine` | yolov8n.engine
|
| 12 |
+
CoreML | `coreml` | yolov8n.mlpackage
|
| 13 |
+
TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
|
| 14 |
+
TensorFlow GraphDef | `pb` | yolov8n.pb
|
| 15 |
+
TensorFlow Lite | `tflite` | yolov8n.tflite
|
| 16 |
+
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
|
| 17 |
+
TensorFlow.js | `tfjs` | yolov8n_web_model/
|
| 18 |
+
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
| 19 |
+
ncnn | `ncnn` | yolov8n_ncnn_model/
|
| 20 |
+
|
| 21 |
+
Requirements:
|
| 22 |
+
$ pip install "ultralytics[export]"
|
| 23 |
+
|
| 24 |
+
Python:
|
| 25 |
+
from ultralytics import YOLO
|
| 26 |
+
model = YOLO('yolov8n.pt')
|
| 27 |
+
results = model.export(format='onnx')
|
| 28 |
+
|
| 29 |
+
CLI:
|
| 30 |
+
$ yolo mode=export model=yolov8n.pt format=onnx
|
| 31 |
+
|
| 32 |
+
Inference:
|
| 33 |
+
$ yolo predict model=yolov8n.pt # PyTorch
|
| 34 |
+
yolov8n.torchscript # TorchScript
|
| 35 |
+
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
| 36 |
+
yolov8n_openvino_model # OpenVINO
|
| 37 |
+
yolov8n.engine # TensorRT
|
| 38 |
+
yolov8n.mlpackage # CoreML (macOS-only)
|
| 39 |
+
yolov8n_saved_model # TensorFlow SavedModel
|
| 40 |
+
yolov8n.pb # TensorFlow GraphDef
|
| 41 |
+
yolov8n.tflite # TensorFlow Lite
|
| 42 |
+
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
| 43 |
+
yolov8n_paddle_model # PaddlePaddle
|
| 44 |
+
|
| 45 |
+
TensorFlow.js:
|
| 46 |
+
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
| 47 |
+
$ npm install
|
| 48 |
+
$ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
|
| 49 |
+
$ npm start
|
| 50 |
+
"""
|
| 51 |
+
import json
|
| 52 |
+
import os
|
| 53 |
+
import shutil
|
| 54 |
+
import subprocess
|
| 55 |
+
import time
|
| 56 |
+
import warnings
|
| 57 |
+
from copy import deepcopy
|
| 58 |
+
from datetime import datetime
|
| 59 |
+
from pathlib import Path
|
| 60 |
+
|
| 61 |
+
import numpy as np
|
| 62 |
+
import torch
|
| 63 |
+
|
| 64 |
+
from yolov8_model.ultralytics.cfg import get_cfg
|
| 65 |
+
from yolov8_model.ultralytics.data.dataset import YOLODataset
|
| 66 |
+
from yolov8_model.ultralytics.data.utils import check_det_dataset
|
| 67 |
+
from yolov8_model.ultralytics.nn.autobackend import check_class_names, default_class_names
|
| 68 |
+
from yolov8_model.ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
| 69 |
+
from yolov8_model.ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
| 70 |
+
from yolov8_model.ultralytics.utils import (
|
| 71 |
+
ARM64,
|
| 72 |
+
DEFAULT_CFG,
|
| 73 |
+
LINUX,
|
| 74 |
+
LOGGER,
|
| 75 |
+
MACOS,
|
| 76 |
+
ROOT,
|
| 77 |
+
WINDOWS,
|
| 78 |
+
__version__,
|
| 79 |
+
callbacks,
|
| 80 |
+
colorstr,
|
| 81 |
+
get_default_args,
|
| 82 |
+
yaml_save,
|
| 83 |
+
)
|
| 84 |
+
from yolov8_model.ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
|
| 85 |
+
from yolov8_model.ultralytics.utils.downloads import attempt_download_asset, get_github_assets
|
| 86 |
+
from yolov8_model.ultralytics.utils.files import file_size, spaces_in_path
|
| 87 |
+
from yolov8_model.ultralytics.utils.ops import Profile
|
| 88 |
+
from yolov8_model.ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def export_formats():
    """Return a pandas DataFrame describing the available YOLOv8 export formats.

    Columns are Format (display name), Argument (the `format=` CLI value),
    Suffix (file or directory suffix of the exported artifact), and CPU/GPU
    inference-support flags.
    """
    import pandas  # lazy import: only needed when this table is requested

    columns = ["Format", "Argument", "Suffix", "CPU", "GPU"]
    rows = [
        ["PyTorch", "-", ".pt", True, True],
        ["TorchScript", "torchscript", ".torchscript", True, True],
        ["ONNX", "onnx", ".onnx", True, True],
        ["OpenVINO", "openvino", "_openvino_model", True, False],
        ["TensorRT", "engine", ".engine", False, True],
        ["CoreML", "coreml", ".mlpackage", True, False],
        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
        ["TensorFlow GraphDef", "pb", ".pb", True, True],
        ["TensorFlow Lite", "tflite", ".tflite", True, False],
        ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
        ["TensorFlow.js", "tfjs", "_web_model", True, False],
        ["PaddlePaddle", "paddle", "_paddle_model", True, True],
        ["ncnn", "ncnn", "_ncnn_model", True, True],
    ]
    return pandas.DataFrame(rows, columns=columns)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def gd_outputs(gd):
    """Return the sorted output tensor names ('name:0') of a TensorFlow GraphDef.

    A node is an output if no other node consumes it as an input; NoOp nodes
    are excluded.
    """
    produced = set()
    consumed = set()
    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
        produced.add(node.name)
        consumed.update(node.input)
    candidates = (name for name in produced - consumed if not name.startswith("NoOp"))
    return sorted(f"{name}:0" for name in candidates)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def try_export(inner_func):
    """YOLOv8 export decorator, i.e. @try_export.

    Wraps an `export_*` method: times the call with Profile, logs success with
    the output path and file size, or logs failure and re-raises the original
    exception unchanged.

    Args:
        inner_func (callable): Export function returning `(file_path, model)`;
            must declare a `prefix` keyword default used for log messages.

    Returns:
        (callable): The wrapped export function.
    """
    from functools import wraps  # stdlib; local import matches this file's lazy-import style

    inner_args = get_default_args(inner_func)

    @wraps(inner_func)  # preserve the wrapped function's name/docstring for debugging
    def outer_func(*args, **kwargs):
        """Export a model, logging timing and success/failure."""
        prefix = inner_args["prefix"]
        try:
            with Profile() as dt:
                f, model = inner_func(*args, **kwargs)
            LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
            return f, model
        except Exception as e:
            LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
            raise  # bare raise keeps the original traceback without appending an extra frame

    return outer_func
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class Exporter:
|
| 142 |
+
"""
|
| 143 |
+
A class for exporting a model.
|
| 144 |
+
|
| 145 |
+
Attributes:
|
| 146 |
+
args (SimpleNamespace): Configuration for the exporter.
|
| 147 |
+
callbacks (list, optional): List of callback functions. Defaults to None.
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
| 151 |
+
"""
|
| 152 |
+
Initializes the Exporter class.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
| 156 |
+
overrides (dict, optional): Configuration overrides. Defaults to None.
|
| 157 |
+
_callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
|
| 158 |
+
"""
|
| 159 |
+
self.args = get_cfg(cfg, overrides)
|
| 160 |
+
if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
|
| 161 |
+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
|
| 162 |
+
|
| 163 |
+
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
| 164 |
+
callbacks.add_integration_callbacks(self)
|
| 165 |
+
|
| 166 |
+
    @smart_inference_mode()
    def __call__(self, model=None):
        """Run the export pipeline for the requested format and return the exported file/dir path.

        The sequence is strictly ordered: normalize the format string, pick a
        device, validate args, build a dummy input, prepare a frozen copy of the
        model, collect metadata, then dispatch to the per-format export method.
        Fires 'on_export_start'/'on_export_end' callbacks around the run.

        Args:
            model (nn.Module, optional): the PyTorch model to export.

        Returns:
            (str): path of the last exported file or directory.

        Raises:
            ValueError: if `format` is not one of the supported export formats.
            SystemError: if Edge TPU export is requested on a non-Linux OS.
        """
        self.run_callbacks("on_export_start")
        t = time.time()
        fmt = self.args.format.lower()  # to lowercase
        if fmt in ("tensorrt", "trt"):  # 'engine' aliases
            fmt = "engine"
        if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"):  # 'coreml' aliases
            fmt = "coreml"
        # Skip the first row ('PyTorch', argument '-') — it is not an export target
        fmts = tuple(export_formats()["Argument"][1:])  # available export formats
        flags = [x == fmt for x in fmts]
        if sum(flags) != 1:
            raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
        # Unpack one boolean per format; exactly one is True
        jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags  # export booleans

        # Device
        if fmt == "engine" and self.args.device is None:
            LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
            self.args.device = "0"
        self.device = select_device("cpu" if self.args.device is None else self.args.device)

        # Checks
        if not hasattr(model, "names"):
            model.names = default_class_names()
        model.names = check_class_names(model.names)
        if self.args.half and onnx and self.device.type == "cpu":
            LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0")
            self.args.half = False
            assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one."
        self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
        if self.args.optimize:
            assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
            assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
        if edgetpu and not LINUX:
            raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")

        # Input: a zero tensor of shape (batch, 3, *imgsz) used for tracing/dry runs
        im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
        file = Path(
            getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
        )
        if file.suffix in {".yaml", ".yml"}:
            file = Path(file.name)

        # Update model: work on a frozen, fused, eval-mode copy so the caller's model is untouched
        model = deepcopy(model).to(self.device)
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        model.float()
        model = model.fuse()
        for m in model.modules():
            if isinstance(m, (Detect, RTDETRDecoder)):  # Segment and Pose use Detect base class
                m.dynamic = self.args.dynamic
                m.export = True
                m.format = self.args.format
            elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
                # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
                m.forward = m.forward_split

        y = None
        for _ in range(2):
            y = model(im)  # dry runs
        if self.args.half and onnx and self.device.type != "cpu":
            im, model = im.half(), model.half()  # to FP16

        # Filter warnings
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)  # suppress TracerWarning
        warnings.filterwarnings("ignore", category=UserWarning)  # suppress shape prim::Constant missing ONNX warning
        warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress CoreML np.bool deprecation warning

        # Assign: stash tensors/paths on self for the per-format export methods
        self.im = im
        self.model = model
        self.file = file
        self.output_shape = (
            tuple(y.shape)
            if isinstance(y, torch.Tensor)
            else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
        )
        self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
        data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
        description = f'Ultralytics {self.pretty_name} model {f"trained on {data}" if data else ""}'
        self.metadata = {
            "description": description,
            "author": "Ultralytics",
            "license": "AGPL-3.0 https://ultralytics.com/license",
            "date": datetime.now().isoformat(),
            "version": __version__,
            "stride": int(max(model.stride)),
            "task": model.task,
            "batch": self.args.batch,
            "imgsz": self.imgsz,
            "names": model.names,
        }  # model metadata
        if model.task == "pose":
            self.metadata["kpt_shape"] = model.model[-1].kpt_shape

        LOGGER.info(
            f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
            f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)'
        )

        # Exports: dispatch by format; f collects one filename slot per format
        f = [""] * len(fmts)  # exported filenames
        if jit or ncnn:  # TorchScript (also a prerequisite for ncnn via PNNX)
            f[0], _ = self.export_torchscript()
        if engine:  # TensorRT required before ONNX
            f[1], _ = self.export_engine()
        if onnx or xml:  # OpenVINO requires ONNX
            f[2], _ = self.export_onnx()
        if xml:  # OpenVINO
            f[3], _ = self.export_openvino()
        if coreml:  # CoreML
            f[4], _ = self.export_coreml()
        if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
            self.args.int8 |= edgetpu  # Edge TPU requires full-integer quantization
            f[5], keras_model = self.export_saved_model()
            if pb or tfjs:  # pb prerequisite to tfjs
                f[6], _ = self.export_pb(keras_model=keras_model)
            if tflite:
                f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
            if edgetpu:
                f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
            if tfjs:
                f[9], _ = self.export_tfjs()
        if paddle:  # PaddlePaddle
            f[10], _ = self.export_paddle()
        if ncnn:  # ncnn
            f[11], _ = self.export_ncnn()

        # Finish
        f = [str(x) for x in f if x]  # filter out '' and None
        if any(f):
            f = str(Path(f[-1]))  # the last populated slot is the final artifact
            square = self.imgsz[0] == self.imgsz[1]
            s = (
                ""
                if square
                else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
                f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
            )
            imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
            predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
            q = "int8" if self.args.int8 else "half" if self.args.half else ""  # quantization
            LOGGER.info(
                f'\nExport complete ({time.time() - t:.1f}s)'
                f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}'
                f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}'
                f'\nVisualize: https://netron.app'
            )

        self.run_callbacks("on_export_end")
        return f  # return list of exported files/dirs
|
| 322 |
+
|
| 323 |
+
@try_export
|
| 324 |
+
def export_torchscript(self, prefix=colorstr("TorchScript:")):
|
| 325 |
+
"""YOLOv8 TorchScript model export."""
|
| 326 |
+
LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
|
| 327 |
+
f = self.file.with_suffix(".torchscript")
|
| 328 |
+
|
| 329 |
+
ts = torch.jit.trace(self.model, self.im, strict=False)
|
| 330 |
+
extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
|
| 331 |
+
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
| 332 |
+
LOGGER.info(f"{prefix} optimizing for mobile...")
|
| 333 |
+
from torch.utils.mobile_optimizer import optimize_for_mobile
|
| 334 |
+
|
| 335 |
+
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
|
| 336 |
+
else:
|
| 337 |
+
ts.save(str(f), _extra_files=extra_files)
|
| 338 |
+
return f, None
|
| 339 |
+
|
| 340 |
+
    @try_export
    def export_onnx(self, prefix=colorstr("ONNX:")):
        """YOLOv8 ONNX export.

        Exports self.model to <file>.onnx via torch.onnx.export, optionally
        simplifying the graph with onnxsim, then embeds self.metadata as ONNX
        metadata_props and saves the final model.

        Returns:
            (str): path to the saved .onnx file.
            (onnx.ModelProto): the loaded (possibly simplified) ONNX model.
        """
        requirements = ["onnx>=1.12.0"]
        if self.args.simplify:
            requirements += ["onnxsim>=0.4.33", "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"]
        check_requirements(requirements)
        import onnx  # noqa

        opset_version = self.args.opset or get_latest_opset()
        LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
        f = str(self.file.with_suffix(".onnx"))

        # Segmentation models emit two outputs (detections + mask protos)
        output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
        dynamic = self.args.dynamic
        if dynamic:
            # Mark batch and spatial dims as dynamic axes
            dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
            if isinstance(self.model, SegmentationModel):
                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
                dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
            elif isinstance(self.model, DetectionModel):
                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)

        torch.onnx.export(
            self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
            self.im.cpu() if dynamic else self.im,
            f,
            verbose=False,
            opset_version=opset_version,
            do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
            input_names=["images"],
            output_names=output_names,
            dynamic_axes=dynamic or None,
        )

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        # onnx.checker.check_model(model_onnx)  # check onnx model

        # Simplify (best-effort: failures are logged, not fatal)
        if self.args.simplify:
            try:
                import onnxsim

                LOGGER.info(f"{prefix} simplifying with onnxsim {onnxsim.__version__}...")
                # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
                model_onnx, check = onnxsim.simplify(model_onnx)
                assert check, "Simplified ONNX model could not be validated"
            except Exception as e:
                LOGGER.info(f"{prefix} simplifier failure: {e}")

        # Metadata: stored as string key/value pairs in the ONNX model
        for k, v in self.metadata.items():
            meta = model_onnx.metadata_props.add()
            meta.key, meta.value = k, str(v)

        onnx.save(model_onnx, f)
        return f, model_onnx
|
| 398 |
+
|
| 399 |
+
    @try_export
    def export_openvino(self, prefix=colorstr("OpenVINO:")):
        """YOLOv8 OpenVINO export.

        Converts the previously exported ONNX model to OpenVINO IR. When
        `int8=True`, additionally quantizes with NNCF using images from
        `self.args.data` for calibration and returns the INT8 model directory
        instead of the FP32/FP16 one.

        Returns:
            (str): path to the exported `_openvino_model` (or `_int8_openvino_model`) directory.
            (None): no in-memory model is returned.
        """
        check_requirements("openvino-dev>=2023.0")  # requires openvino-dev: https://pypi.org/project/openvino-dev/
        import openvino.runtime as ov  # noqa
        from openvino.tools import mo  # noqa

        LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
        f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
        fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
        f_onnx = self.file.with_suffix(".onnx")  # assumes export_onnx() already produced this file
        f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
        fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)

        def serialize(ov_model, file):
            """Set RT info, serialize and save metadata YAML."""
            ov_model.set_rt_info("YOLOv8", ["model_info", "model_type"])
            ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
            ov_model.set_rt_info(114, ["model_info", "pad_value"])
            ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
            ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
            ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
            if self.model.task != "classify":
                ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])

            ov.serialize(ov_model, file)  # save
            yaml_save(Path(file).parent / "metadata.yaml", self.metadata)  # add metadata.yaml

        ov_model = mo.convert_model(
            f_onnx, model_name=self.pretty_name, framework="onnx", compress_to_fp16=self.args.half
        )  # export

        if self.args.int8:
            if not self.args.data:
                # INT8 needs a calibration dataset; fall back to a default
                self.args.data = DEFAULT_CFG.data or "coco128.yaml"
                LOGGER.warning(
                    f"{prefix} WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. "
                    f"Using default 'data={self.args.data}'."
                )
            check_requirements("nncf>=2.5.0")
            import nncf

            def transform_fn(data_item):
                """Quantization transform function."""
                assert (
                    data_item["img"].dtype == torch.uint8
                ), "Input image must be uint8 for the quantization preprocessing"
                im = data_item["img"].numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
                return np.expand_dims(im, 0) if im.ndim == 3 else im

            # Generate calibration data for integer quantization
            LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
            data = check_det_dataset(self.args.data)
            dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
            n = len(dataset)
            if n < 300:
                LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
            quantization_dataset = nncf.Dataset(dataset, transform_fn)
            # Keep these ops in higher precision to preserve accuracy
            ignored_scope = nncf.IgnoredScope(types=["Multiply", "Subtract", "Sigmoid"])  # ignore operation
            quantized_ov_model = nncf.quantize(
                ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope
            )
            serialize(quantized_ov_model, fq_ov)
            return fq, None

        serialize(ov_model, f_ov)
        return f, None
|
| 466 |
+
|
| 467 |
+
@try_export
|
| 468 |
+
def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
|
| 469 |
+
"""YOLOv8 Paddle export."""
|
| 470 |
+
check_requirements(("paddlepaddle", "x2paddle"))
|
| 471 |
+
import x2paddle # noqa
|
| 472 |
+
from x2paddle.convert import pytorch2paddle # noqa
|
| 473 |
+
|
| 474 |
+
LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
|
| 475 |
+
f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")
|
| 476 |
+
|
| 477 |
+
pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im]) # export
|
| 478 |
+
yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
|
| 479 |
+
return f, None
|
| 480 |
+
|
| 481 |
+
    @try_export
    def export_ncnn(self, prefix=colorstr("ncnn:")):
        """
        YOLOv8 ncnn export using PNNX https://github.com/pnnx/pnnx.

        Converts the previously exported TorchScript model (<file>.torchscript)
        into a `_ncnn_model` directory by running the PNNX binary as a
        subprocess, downloading the binary first if it is not found locally.

        Returns:
            (str): path to the exported `_ncnn_model` directory.
            (None): no in-memory model is returned.
        """
        check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires ncnn
        import ncnn  # noqa

        LOGGER.info(f"\n{prefix} starting export with ncnn {ncnn.__version__}...")
        f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
        f_ts = self.file.with_suffix(".torchscript")  # assumes export_torchscript() already ran

        name = Path("pnnx.exe" if WINDOWS else "pnnx")  # PNNX filename
        pnnx = name if name.is_file() else ROOT / name
        if not pnnx.is_file():
            # Binary missing: download a platform-matching release asset from GitHub
            LOGGER.warning(
                f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from "
                "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
                f"or in {ROOT}. See PNNX repo for full installation instructions."
            )
            system = ["macos"] if MACOS else ["windows"] if WINDOWS else ["ubuntu", "linux"]  # operating system
            try:
                _, assets = get_github_assets(repo="pnnx/pnnx", retry=True)
                url = [x for x in assets if any(s in x for s in system)][0]
            except Exception as e:
                # Fall back to a pinned release if the GitHub API is unavailable
                url = f"https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip"
                LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}")
            asset = attempt_download_asset(url, repo="pnnx/pnnx", release="latest")
            if check_is_path_safe(Path.cwd(), asset):  # avoid path traversal security vulnerability
                unzip_dir = Path(asset).with_suffix("")
                (unzip_dir / name).rename(pnnx)  # move binary to ROOT
                shutil.rmtree(unzip_dir)  # delete unzip dir
                Path(asset).unlink()  # delete zip
            pnnx.chmod(0o777)  # set read, write, and execute permissions for everyone

        # Output paths passed to PNNX as key=value CLI arguments
        ncnn_args = [
            f'ncnnparam={f / "model.ncnn.param"}',
            f'ncnnbin={f / "model.ncnn.bin"}',
            f'ncnnpy={f / "model_ncnn.py"}',
        ]

        pnnx_args = [
            f'pnnxparam={f / "model.pnnx.param"}',
            f'pnnxbin={f / "model.pnnx.bin"}',
            f'pnnxpy={f / "model_pnnx.py"}',
            f'pnnxonnx={f / "model.pnnx.onnx"}',
        ]

        cmd = [
            str(pnnx),
            str(f_ts),
            *ncnn_args,
            *pnnx_args,
            f"fp16={int(self.args.half)}",
            f"device={self.device.type}",
            f'inputshape="{[self.args.batch, 3, *self.imgsz]}"',
        ]
        f.mkdir(exist_ok=True)  # make ncnn_model directory
        LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
        subprocess.run(cmd, check=True)  # list-form args, shell=False (safe)

        # Remove debug files: PNNX intermediates are not needed in the final model dir
        pnnx_files = [x.split("=")[-1] for x in pnnx_args]
        for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
            Path(f_debug).unlink(missing_ok=True)

        yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
        return str(f), None
|
| 549 |
+
|
| 550 |
+
    @try_export
    def export_coreml(self, prefix=colorstr("CoreML:")):
        """YOLOv8 CoreML export.

        Traces the model to TorchScript, converts it with coremltools to either
        a legacy *.mlmodel (format='mlmodel') or an *.mlpackage, optionally
        quantizes weights (int8/half), optionally builds an NMS pipeline for
        detect models, then attaches metadata and saves.

        Returns:
            (Path): path to the saved .mlpackage / .mlmodel.
            (ct.models.MLModel): the converted CoreML model.
        """
        mlmodel = self.args.format.lower() == "mlmodel"  # legacy *.mlmodel export format requested
        check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=7.0")
        import coremltools as ct  # noqa

        LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
        assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux."
        f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
        if f.is_dir():
            shutil.rmtree(f)  # .mlpackage is a directory; remove any stale export

        # Input normalization baked into the CoreML image preprocessing
        bias = [0.0, 0.0, 0.0]
        scale = 1 / 255
        classifier_config = None
        if self.model.task == "classify":
            classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
            model = self.model
        elif self.model.task == "detect":
            # Wrap with IOSDetectModel when an NMS pipeline is requested
            model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
        else:
            if self.args.nms:
                LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolov8n.pt'.")
                # TODO CoreML Segment and Pose model pipelining
            model = self.model

        ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
        ct_model = ct.convert(
            ts,
            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
            classifier_config=classifier_config,
            convert_to="neuralnetwork" if mlmodel else "mlprogram",
        )
        # Quantization settings: int8 -> 8-bit kmeans, half -> 16-bit linear, else none
        bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
        if bits < 32:
            if "kmeans" in mode:
                check_requirements("scikit-learn")  # scikit-learn package required for k-means quantization
            if mlmodel:
                ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
            elif bits == 8:  # mlprogram already quantized to FP16
                import coremltools.optimize.coreml as cto

                op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
                config = cto.OptimizationConfig(global_config=op_config)
                ct_model = cto.palettize_weights(ct_model, config=config)
        if self.args.nms and self.model.task == "detect":
            if mlmodel:
                import platform

                # coremltools<=6.2 NMS export requires Python<3.11
                check_version(platform.python_version(), "<3.11", name="Python ", hard=True)
                weights_dir = None
            else:
                ct_model.save(str(f))  # save otherwise weights_dir does not exist
                weights_dir = str(f / "Data/com.apple.CoreML/weights")
            ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)

        # NOTE: pop() mutates self.metadata — description/author/license/version are
        # moved into dedicated CoreML fields, the rest into user_defined_metadata
        m = self.metadata  # metadata dict
        ct_model.short_description = m.pop("description")
        ct_model.author = m.pop("author")
        ct_model.license = m.pop("license")
        ct_model.version = m.pop("version")
        ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
        try:
            ct_model.save(str(f))  # save *.mlpackage
        except Exception as e:
            # Fallback path for known coremltools save bugs: retry as legacy *.mlmodel
            LOGGER.warning(
                f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
                f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
            )
            f = f.with_suffix(".mlmodel")
            ct_model.save(str(f))
        return f, ct_model
|
| 624 |
+
|
| 625 |
+
@try_export
def export_engine(self, prefix=colorstr("TensorRT:")):
    """
    YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt.

    Builds a serialized TensorRT engine from the model's ONNX export and writes it to a *.engine
    file, prefixed with a little-endian length-delimited JSON metadata header.

    Args:
        prefix (str): Colored logging prefix for all messages emitted by this method.

    Returns:
        (tuple): Tuple of (engine file path, None).
    """
    assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
    # ONNX must be exported BEFORE importing tensorrt, see https://github.com/ultralytics/ultralytics/issues/7016
    f_onnx, _ = self.export_onnx()  # run before trt import https://github.com/ultralytics/ultralytics/issues/7016

    try:
        import tensorrt as trt  # noqa
    except ImportError:
        if LINUX:
            # NVIDIA's pip index hosts the tensorrt wheels; auto-install is only attempted on Linux
            check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
        import tensorrt as trt  # noqa

    check_version(trt.__version__, "7.0.0", hard=True)  # require tensorrt>=7.0.0

    self.args.simplify = True  # simplified ONNX parses more reliably in TensorRT

    LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
    assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
    f = self.file.with_suffix(".engine")  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if self.args.verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    # '*' binds tighter than '<<', so this is (workspace * 1) << 30, i.e. workspace GiB
    config.max_workspace_size = self.args.workspace * 1 << 30
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice

    # Explicit-batch network definition is required to parse ONNX models
    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(f_onnx):
        raise RuntimeError(f"failed to load ONNX file: {f_onnx}")

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

    if self.args.dynamic:
        # Optimization profile: min batch 1, opt batch//2, max the traced batch size
        shape = self.im.shape
        if shape[0] <= 1:
            LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
        config.add_optimization_profile(profile)

    LOGGER.info(
        f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
    )
    if builder.platform_has_fast_fp16 and self.args.half:
        config.set_flag(trt.BuilderFlag.FP16)

    # Free the PyTorch model's GPU memory before the memory-hungry engine build
    del self.model
    torch.cuda.empty_cache()

    # Write file: 4-byte little-endian metadata length, JSON metadata, then the serialized engine
    with builder.build_engine(network, config) as engine, open(f, "wb") as t:
        # Metadata
        meta = json.dumps(self.metadata)
        t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
        t.write(meta.encode())
        # Model
        t.write(engine.serialize())

    return f, None
|
| 695 |
+
|
| 696 |
+
@try_export
def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
    """
    YOLOv8 TensorFlow SavedModel export.

    Converts the model to ONNX and then, via onnx2tf, to a TensorFlow SavedModel directory that
    also contains FP32/FP16/INT8 TFLite variants. When INT8 export is requested with a dataset,
    up to 100 validation images are collected as a calibration set.

    Args:
        prefix (str): Colored logging prefix for all messages emitted by this method.

    Returns:
        (tuple): Tuple of (SavedModel directory path (str), loaded tf.saved_model object).
    """
    cuda = torch.cuda.is_available()
    try:
        import tensorflow as tf  # noqa
    except ImportError:
        check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
        import tensorflow as tf  # noqa
    check_requirements(
        (
            "onnx",
            "onnx2tf>=1.15.4,<=1.17.5",
            "sng4onnx>=1.0.1",
            "onnxsim>=0.4.33",
            "onnx_graphsurgeon>=0.3.26",
            "tflite_support",
            "onnxruntime-gpu" if cuda else "onnxruntime",
        ),
        cmds="--extra-index-url https://pypi.ngc.nvidia.com",
    )  # onnx_graphsurgeon only on NVIDIA

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    check_version(
        tf.__version__,
        "<=2.13.1",
        name="tensorflow",
        verbose=True,
        msg="https://github.com/ultralytics/ultralytics/issues/5161",
    )
    f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
    if f.is_dir():
        import shutil

        shutil.rmtree(f)  # delete output folder

    # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
    onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
    if not onnx2tf_file.exists():
        attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)

    # Export to ONNX
    self.args.simplify = True
    f_onnx, _ = self.export_onnx()

    # Export to TF
    tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
    if self.args.int8:
        verbosity = "--verbosity info"
        if self.args.data:
            # Generate calibration data for integer quantization
            LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
            data = check_det_dataset(self.args.data)
            dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
            images = []
            for i, batch in enumerate(dataset):
                if i >= 100:  # maximum number of calibration images
                    break
                im = batch["img"].permute(1, 2, 0)[None]  # list to nparray, CHW to BHWC
                images.append(im)
            f.mkdir()
            images = torch.cat(images, 0).float()
            # mean = images.view(-1, 3).mean(0)  # imagenet mean [123.675, 116.28, 103.53]
            # std = images.view(-1, 3).std(0)  # imagenet std [58.395, 57.12, 57.375]
            np.save(str(tmp_file), images.numpy())  # BHWC
            int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
        else:
            int8 = "-oiqt -qt per-tensor"
    else:
        verbosity = "--non_verbose"
        int8 = ""

    cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo {verbosity} {int8}'.strip()
    LOGGER.info(f"{prefix} running '{cmd}'")
    subprocess.run(cmd, shell=True)
    yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml

    # Remove/rename TFLite models
    if self.args.int8:
        tmp_file.unlink(missing_ok=True)
        for file in f.rglob("*_dynamic_range_quant.tflite"):
            file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
        for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
            file.unlink()  # delete extra fp16 activation TFLite files

    # Add TFLite metadata
    # BUGFIX: the condition and unlink previously used the directory 'f' instead of the loop
    # variable 'file', so the int16-activation check could never match and stray files were
    # passed to _add_tflite_metadata instead of being deleted.
    for file in f.rglob("*.tflite"):
        file.unlink() if "quant_with_int16_act.tflite" in str(file) else self._add_tflite_metadata(file)

    return str(f), tf.saved_model.load(f, tags=None, options=None)  # load saved_model as Keras model
|
| 786 |
+
|
| 787 |
+
@try_export
def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
    """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
    import tensorflow as tf  # noqa
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    f = self.file.with_suffix(".pb")

    # Wrap the Keras model in a tf.function and freeze it to a single concrete signature
    wrapped = tf.function(lambda x: keras_model(x))  # full model
    spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
    concrete = wrapped.get_concrete_function(spec)
    frozen_func = convert_variables_to_constants_v2(concrete)
    frozen_func.graph.as_graph_def()
    # Serialize the frozen graph to disk as a binary *.pb file
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
    return f, None
|
| 802 |
+
|
| 803 |
+
@try_export
def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
    """YOLOv8 TensorFlow Lite export."""
    import tensorflow as tf  # noqa

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
    # The SavedModel export step already produced the TFLite files; just pick the right variant
    precision = "int8" if self.args.int8 else "float16" if self.args.half else "float32"  # fp32 in/out
    f = saved_model / f"{self.file.stem}_{precision}.tflite"
    return str(f), None
|
| 817 |
+
|
| 818 |
+
@try_export
def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
    """
    YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/.

    Compiles a quantized TFLite model for the Coral Edge TPU. Requires Linux; installs the
    edgetpu_compiler via apt if it is not already on PATH.

    Args:
        tflite_model (str): Path to the source *.tflite model to compile.
        prefix (str): Colored logging prefix for all messages emitted by this method.

    Returns:
        (tuple): Tuple of (compiled *_edgetpu.tflite path (str), None).
    """
    LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")

    cmd = "edgetpu_compiler --version"
    help_url = "https://coral.ai/docs/edgetpu/compiler/"
    assert LINUX, f"export only supported on Linux. See {help_url}"
    # Probe for the compiler; non-zero return code means it is not installed
    if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
        LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
        sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0  # sudo installed on system
        # Register Google's Coral apt repository and install the compiler; drop 'sudo ' when unavailable
        for c in (
            "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
            'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
            "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
            "sudo apt-get update",
            "sudo apt-get install edgetpu-compiler",
        ):
            subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True)
    # Last whitespace-separated token of '--version' output is the version number
    ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

    LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
    f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model

    cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
    LOGGER.info(f"{prefix} running '{cmd}'")
    subprocess.run(cmd, shell=True)
    self._add_tflite_metadata(f)
    return f, None
|
| 847 |
+
|
| 848 |
+
@try_export
def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
    """
    YOLOv8 TensorFlow.js export.

    Converts the previously exported frozen *.pb GraphDef to a TensorFlow.js web model directory
    using the tensorflowjs_converter CLI, with optional float16/uint8 quantization.

    Args:
        prefix (str): Colored logging prefix for all messages emitted by this method.

    Returns:
        (tuple): Tuple of (web model directory path (str), None).
    """
    # JAX bug requiring install constraints in https://github.com/google/jax/issues/18978
    check_requirements(["jax<=0.4.21", "jaxlib<=0.4.21", "tensorflowjs"])
    import tensorflow as tf
    import tensorflowjs as tfjs  # noqa

    LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
    f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
    f_pb = str(self.file.with_suffix(".pb"))  # *.pb path

    # Parse the frozen GraphDef to recover its output node names for the converter CLI
    gd = tf.Graph().as_graph_def()  # TF GraphDef
    with open(f_pb, "rb") as file:
        gd.ParseFromString(file.read())
    outputs = ",".join(gd_outputs(gd))
    LOGGER.info(f"\n{prefix} output node names: {outputs}")

    quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
    with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
        cmd = f'tensorflowjs_converter --input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
        LOGGER.info(f"{prefix} running '{cmd}'")
        subprocess.run(cmd, shell=True)

    if " " in f:
        LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")

    # f_json = Path(f) / 'model.json'  # *.json path
    # with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
    #     subst = re.sub(
    #         r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
    #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
    #         r'"Identity.?.?": {"name": "Identity.?.?"}, '
    #         r'"Identity.?.?": {"name": "Identity.?.?"}}}',
    #         r'{"outputs": {"Identity": {"name": "Identity"}, '
    #         r'"Identity_1": {"name": "Identity_1"}, '
    #         r'"Identity_2": {"name": "Identity_2"}, '
    #         r'"Identity_3": {"name": "Identity_3"}}}',
    #         f_json.read_text(),
    #     )
    #     j.write(subst)
    yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
    return f, None
|
| 891 |
+
|
| 892 |
+
def _add_tflite_metadata(self, file):
    """
    Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata.

    Embeds the model description/version/author/license plus input/output tensor metadata into the
    TFLite flatbuffer in place, and attaches the full metadata dict as an associated label file.

    Args:
        file (str | Path): Path to the *.tflite model to annotate in place.
    """
    from tflite_support import flatbuffers  # noqa
    from tflite_support import metadata as _metadata  # noqa
    from tflite_support import metadata_schema_py_generated as _metadata_fb  # noqa

    # Create model info
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = self.metadata["description"]
    model_meta.version = self.metadata["version"]
    model_meta.author = self.metadata["author"]
    model_meta.license = self.metadata["license"]

    # Label file: the stringified metadata dict is written to a temp file and attached,
    # then deleted after being packed into the flatbuffer
    tmp_file = Path(file).parent / "temp_meta.txt"
    with open(tmp_file, "w") as f:
        f.write(str(self.metadata))

    label_file = _metadata_fb.AssociatedFileT()
    label_file.name = tmp_file.name
    label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS

    # Create input info (single RGB image input)
    input_meta = _metadata_fb.TensorMetadataT()
    input_meta.name = "image"
    input_meta.description = "Input image to be detected."
    input_meta.content = _metadata_fb.ContentT()
    input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
    input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
    input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties

    # Create output info; segmentation models get a second output for mask protos
    output1 = _metadata_fb.TensorMetadataT()
    output1.name = "output"
    output1.description = "Coordinates of detected objects, class labels, and confidence score"
    output1.associatedFiles = [label_file]
    if self.model.task == "segment":
        output2 = _metadata_fb.TensorMetadataT()
        output2.name = "output"
        output2.description = "Mask protos"
        output2.associatedFiles = [label_file]

    # Create subgraph info
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.inputTensorMetadata = [input_meta]
    subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
    model_meta.subgraphMetadata = [subgraph]

    # Serialize the metadata tree into a flatbuffer
    b = flatbuffers.Builder(0)
    b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    metadata_buf = b.Output()

    # Write the metadata and associated file into the model file in place
    populator = _metadata.MetadataPopulator.with_model_file(str(file))
    populator.load_metadata_buffer(metadata_buf)
    populator.load_associated_files([str(tmp_file)])
    populator.populate()
    tmp_file.unlink()
|
| 949 |
+
|
| 950 |
+
def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
    """
    YOLOv8 CoreML pipeline.

    Wraps a converted CoreML detect model in a two-stage pipeline (model + CoreML built-in
    NonMaximumSuppression) so the exported model outputs NMS-filtered 'confidence' and
    'coordinates' features directly, with user-adjustable IoU/confidence thresholds.

    Args:
        model (ct.models.MLModel): Converted CoreML model whose two outputs are per-anchor class
            scores and box coordinates.
        weights_dir (str | None): Weights directory for mlprogram models; None for *.mlmodel.
        prefix (str): Colored logging prefix for all messages emitted by this method.

    Returns:
        (ct.models.MLModel): Pipelined CoreML model with NMS and populated descriptions.
    """
    import coremltools as ct  # noqa

    LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
    _, _, h, w = list(self.im.shape)  # BCHW

    # Output shapes: only macOS can run model.predict(); elsewhere derive shapes from the
    # PyTorch output recorded at export time
    spec = model.get_spec()
    out0, out1 = iter(spec.description.output)
    if MACOS:
        from PIL import Image

        img = Image.new("RGB", (w, h))  # w=192, h=320
        out = model.predict({"image": img})
        out0_shape = out[out0.name].shape  # (3780, 80)
        out1_shape = out[out1.name].shape  # (3780, 4)
    else:  # linux and windows can not run model.predict(), get sizes from PyTorch model output y
        out0_shape = self.output_shape[2], self.output_shape[1] - 4  # (3780, 80)
        out1_shape = self.output_shape[2], 4  # (3780, 4)

    # Checks
    names = self.metadata["names"]
    nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
    _, nc = out0_shape  # number of anchors, number of classes
    # _, nc = out0.type.multiArrayType.shape
    assert len(names) == nc, f"{len(names)} names found for nc={nc}"  # check

    # Define output shapes (missing)
    out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
    out1.type.multiArrayType.shape[:] = out1_shape  # (3780, 4)
    # spec.neuralNetwork.preprocessing[0].featureName = '0'

    # Flexible input shapes
    # from coremltools.models.neural_network import flexible_shape_utils
    # s = []  # shapes
    # s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
    # s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384))  # (height, width)
    # flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
    # r = flexible_shape_utils.NeuralNetworkImageSizeRange()  # shape ranges
    # r.add_height_range((192, 640))
    # r.add_width_range((192, 640))
    # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)

    # Print
    # print(spec.description)

    # Model from spec
    model = ct.models.MLModel(spec, weights_dir=weights_dir)

    # 3. Create NMS protobuf: clone the detect model's output descriptions as the NMS
    # stage's inputs AND outputs, then rename the outputs
    nms_spec = ct.proto.Model_pb2.Model()
    nms_spec.specificationVersion = 5
    for i in range(2):
        decoder_output = model._spec.description.output[i].SerializeToString()
        nms_spec.description.input.add()
        nms_spec.description.input[i].ParseFromString(decoder_output)
        nms_spec.description.output.add()
        nms_spec.description.output[i].ParseFromString(decoder_output)

    nms_spec.description.output[0].name = "confidence"
    nms_spec.description.output[1].name = "coordinates"

    # NMS outputs have a variable first dimension (0..unbounded boxes) and a fixed second
    # dimension (nc scores / 4 coordinates)
    output_sizes = [nc, 4]
    for i in range(2):
        ma_type = nms_spec.description.output[i].type.multiArrayType
        ma_type.shapeRange.sizeRanges.add()
        ma_type.shapeRange.sizeRanges[0].lowerBound = 0
        ma_type.shapeRange.sizeRanges[0].upperBound = -1
        ma_type.shapeRange.sizeRanges.add()
        ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
        ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
        del ma_type.shape[:]

    # Configure CoreML's built-in NonMaximumSuppression stage
    nms = nms_spec.nonMaximumSuppression
    nms.confidenceInputFeatureName = out0.name  # 1x507x80
    nms.coordinatesInputFeatureName = out1.name  # 1x507x4
    nms.confidenceOutputFeatureName = "confidence"
    nms.coordinatesOutputFeatureName = "coordinates"
    nms.iouThresholdInputFeatureName = "iouThreshold"
    nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
    nms.iouThreshold = 0.45
    nms.confidenceThreshold = 0.25
    nms.pickTop.perClass = True
    nms.stringClassLabels.vector.extend(names.values())
    nms_model = ct.models.MLModel(nms_spec)

    # 4. Pipeline models together: image + optional threshold overrides in, NMS results out
    pipeline = ct.models.pipeline.Pipeline(
        input_features=[
            ("image", ct.models.datatypes.Array(3, ny, nx)),
            ("iouThreshold", ct.models.datatypes.Double()),
            ("confidenceThreshold", ct.models.datatypes.Double()),
        ],
        output_features=["confidence", "coordinates"],
    )
    pipeline.add_model(model)
    pipeline.add_model(nms_model)

    # Correct datatypes: copy authoritative feature descriptions from the stage specs
    pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
    pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
    pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())

    # Update metadata
    pipeline.spec.specificationVersion = 5
    pipeline.spec.description.metadata.userDefined.update(
        {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
    )

    # Save the model with human-readable input/output descriptions
    model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
    model.input_description["image"] = "Input image"
    model.input_description["iouThreshold"] = f"(optional) IOU threshold override (default: {nms.iouThreshold})"
    model.input_description[
        "confidenceThreshold"
    ] = f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
    model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
    model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
    LOGGER.info(f"{prefix} pipeline success")
    return model
|
| 1071 |
+
|
| 1072 |
+
def add_callback(self, event: str, callback):
    """Register *callback* to be run when *event* is triggered."""
    handlers = self.callbacks[event]
    handlers.append(callback)
|
| 1075 |
+
|
| 1076 |
+
def run_callbacks(self, event: str):
    """Invoke every callback registered for *event*, in registration order; unknown events are a no-op."""
    registered = self.callbacks.get(event, [])
    for handler in registered:
        handler(self)
|
| 1080 |
+
|
| 1081 |
+
|
| 1082 |
+
class IOSDetectModel(torch.nn.Module):
    """Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""

    def __init__(self, model, im):
        """Initialize the IOSDetectModel class with a YOLO model and example image."""
        super().__init__()
        batch, channels, height, width = im.shape
        self.model = model
        self.nc = len(model.names)  # number of classes
        if width == height:
            # Square input: one scalar factor normalizes all four box terms
            self.normalize = 1.0 / width
        else:
            # Rectangular input: per-term factors, broadcast over boxes (slower, smaller)
            self.normalize = torch.tensor([1.0 / width, 1.0 / height] * 2)

    def forward(self, x):
        """Normalize predictions of object detection model with input size-dependent factors."""
        preds = self.model(x)[0].transpose(0, 1)  # (4+nc, anchors) -> (anchors, 4+nc)
        xywh, cls = preds.split((4, self.nc), 1)
        return cls, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)
|
yolov8_model/ultralytics/engine/model.py
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import inspect
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
from yolov8_model.ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
| 10 |
+
from yolov8_model.ultralytics.hub.utils import HUB_WEB_ROOT
|
| 11 |
+
from yolov8_model.ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
| 12 |
+
from yolov8_model.ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Model(nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
A base class for implementing YOLO models, unifying APIs across different model types.
|
| 18 |
+
|
| 19 |
+
This class provides a common interface for various operations related to YOLO models, such as training,
|
| 20 |
+
validation, prediction, exporting, and benchmarking. It handles different types of models, including those
|
| 21 |
+
loaded from local files, Ultralytics HUB, or Triton Server. The class is designed to be flexible and
|
| 22 |
+
extendable for different tasks and model configurations.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file
|
| 26 |
+
path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
|
| 27 |
+
task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's
|
| 28 |
+
application domain, such as object detection, segmentation, etc. Defaults to None.
|
| 29 |
+
verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False.
|
| 30 |
+
|
| 31 |
+
Attributes:
|
| 32 |
+
callbacks (dict): A dictionary of callback functions for various events during model operations.
|
| 33 |
+
predictor (BasePredictor): The predictor object used for making predictions.
|
| 34 |
+
model (nn.Module): The underlying PyTorch model.
|
| 35 |
+
trainer (BaseTrainer): The trainer object used for training the model.
|
| 36 |
+
ckpt (dict): The checkpoint data if the model is loaded from a *.pt file.
|
| 37 |
+
cfg (str): The configuration of the model if loaded from a *.yaml file.
|
| 38 |
+
ckpt_path (str): The path to the checkpoint file.
|
| 39 |
+
overrides (dict): A dictionary of overrides for model configuration.
|
| 40 |
+
metrics (dict): The latest training/validation metrics.
|
| 41 |
+
session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
|
| 42 |
+
task (str): The type of task the model is intended for.
|
| 43 |
+
model_name (str): The name of the model.
|
| 44 |
+
|
| 45 |
+
Methods:
|
| 46 |
+
__call__: Alias for the predict method, enabling the model instance to be callable.
|
| 47 |
+
_new: Initializes a new model based on a configuration file.
|
| 48 |
+
_load: Loads a model from a checkpoint file.
|
| 49 |
+
_check_is_pytorch_model: Ensures that the model is a PyTorch model.
|
| 50 |
+
reset_weights: Resets the model's weights to their initial state.
|
| 51 |
+
load: Loads model weights from a specified file.
|
| 52 |
+
save: Saves the current state of the model to a file.
|
| 53 |
+
info: Logs or returns information about the model.
|
| 54 |
+
fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference.
|
| 55 |
+
predict: Performs object detection predictions.
|
| 56 |
+
track: Performs object tracking.
|
| 57 |
+
val: Validates the model on a dataset.
|
| 58 |
+
benchmark: Benchmarks the model on various export formats.
|
| 59 |
+
export: Exports the model to different formats.
|
| 60 |
+
train: Trains the model on a dataset.
|
| 61 |
+
tune: Performs hyperparameter tuning.
|
| 62 |
+
_apply: Applies a function to the model's tensors.
|
| 63 |
+
add_callback: Adds a callback function for an event.
|
| 64 |
+
clear_callback: Clears all callbacks for an event.
|
| 65 |
+
reset_callbacks: Resets all callbacks to their default functions.
|
| 66 |
+
_get_hub_session: Retrieves or creates an Ultralytics HUB session.
|
| 67 |
+
is_triton_model: Checks if a model is a Triton Server model.
|
| 68 |
+
is_hub_model: Checks if a model is an Ultralytics HUB model.
|
| 69 |
+
_reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
|
| 70 |
+
_smart_load: Loads the appropriate module based on the model task.
|
| 71 |
+
task_map: Provides a mapping from model tasks to corresponding classes.
|
| 72 |
+
|
| 73 |
+
Raises:
|
| 74 |
+
FileNotFoundError: If the specified model file does not exist or is inaccessible.
|
| 75 |
+
ValueError: If the model file or configuration is invalid or unsupported.
|
| 76 |
+
ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
|
| 77 |
+
TypeError: If the model is not a PyTorch model when required.
|
| 78 |
+
AttributeError: If required attributes or methods are not implemented or available.
|
| 79 |
+
NotImplementedError: If a specific model task or mode is not supported.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None, verbose=False) -> None:
    """
    Set up a YOLO model from a local file, an Ultralytics HUB reference, or a Triton Server URL.

    Args:
        model (Union[str, Path], optional): Path or name of the model to load or create. May be a local
            file, an Ultralytics HUB model reference, or a Triton Server URL. Defaults to 'yolov8n.pt'.
        task (Any, optional): Task type for the model (detection, segmentation, ...). Inferred when None.
        verbose (bool, optional): Enable verbose output during initialization. Defaults to False.

    Raises:
        FileNotFoundError: If the specified model file does not exist or is inaccessible.
        ValueError: If the model file or configuration is invalid or unsupported.
        ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
    """
    super().__init__()
    self.callbacks = callbacks.get_default_callbacks()
    self.predictor = None  # reused across predict() calls
    self.model = None  # underlying nn.Module once loaded/built
    self.trainer = None  # trainer instance
    self.ckpt = None  # checkpoint dict when loaded from *.pt
    self.cfg = None  # config path when built from *.yaml
    self.ckpt_path = None
    self.overrides = {}  # argument overrides for the trainer
    self.metrics = None  # latest train/val metrics
    self.session = None  # optional HUB session
    self.task = task
    self.model_name = model = str(model).strip()  # normalize to a stripped string

    # Ultralytics HUB model reference (https://hub.ultralytics.com)
    if self.is_hub_model(model):
        checks.check_requirements("hub-sdk>0.0.2")
        self.session = self._get_hub_session(model)
        model = self.session.model_file

    # Triton Server URL: store it verbatim and stop here
    elif self.is_triton_model(model):
        self.model = model
        self.task = task
        return

    # Otherwise build from a YAML config or load a weights file
    model = checks.check_model_file_from_stem(model)  # add suffix, e.g. yolov8n -> yolov8n.pt
    if Path(model).suffix in (".yaml", ".yml"):
        self._new(model, task=task)
    else:
        self._load(model, task=task)

    self.model_name = model
|
| 138 |
+
|
| 139 |
+
def __call__(self, source=None, stream=False, **kwargs):
    """
    Make the model instance directly callable as a shorthand for predict().

    Args:
        source (str | int | PIL.Image | np.ndarray, optional): Image source for prediction
            (file path, URL, PIL image, numpy array, ...). Defaults to None.
        stream (bool, optional): Treat the source as a continuous stream. Defaults to False.
        **kwargs (dict): Extra options forwarded to predict().

    Returns:
        (List[ultralytics.engine.results.Results]): Prediction results.
    """
    return self.predict(source, stream, **kwargs)
|
| 157 |
+
|
| 158 |
+
@staticmethod
def _get_hub_session(model: str):
    """
    Create an authenticated Ultralytics HUB training session for the given model.

    Args:
        model (str): HUB model reference.

    Returns:
        (HUBTrainingSession | None): The session when authentication succeeds, else None.
    """
    # Import from the vendored package root to match the top-of-file imports; importing the
    # bare `ultralytics` package here would load a different (duplicate) module instance.
    from yolov8_model.ultralytics.hub.session import HUBTrainingSession

    session = HUBTrainingSession(model)
    return session if session.client.authenticated else None
|
| 165 |
+
|
| 166 |
+
@staticmethod
def is_triton_model(model):
    """Return truthy if `model` is a Triton Server URL of the form <scheme>://<netloc>/<endpoint>/<task_name>."""
    from urllib.parse import urlsplit

    parts = urlsplit(model)
    # A Triton reference needs a host, a path, and an http/grpc scheme
    return parts.netloc and parts.path and parts.scheme in {"http", "grpc"}
|
| 173 |
+
|
| 174 |
+
@staticmethod
def is_hub_model(model):
    """Return True if the given string refers to an Ultralytics HUB model."""
    is_hub_url = model.startswith(f"{HUB_WEB_ROOT}/models/")  # https://hub.ultralytics.com/models/MODEL_ID
    is_apikey_modelid = [len(part) for part in model.split("_")] == [42, 20]  # APIKEY_MODELID
    is_modelid = len(model) == 20 and not Path(model).exists() and all(c not in model for c in "./\\")  # MODELID
    return any((is_hub_url, is_apikey_modelid, is_modelid))
|
| 184 |
+
|
| 185 |
+
def _new(self, cfg: str, task=None, model=None, verbose=True):
    """
    Build a fresh model from a YAML configuration and infer its task if not given.

    Args:
        cfg (str): Model configuration (YAML) file.
        task (str | None): Model task; inferred from the config when None.
        model (BaseModel): Optional customized model factory/class.
        verbose (bool): Display model info on load (only on rank -1).
    """
    cfg_dict = yaml_model_load(cfg)
    self.cfg = cfg
    self.task = task or guess_model_task(cfg_dict)
    model_factory = model or self._smart_load("model")
    self.model = model_factory(cfg_dict, verbose=verbose and RANK == -1)  # build model
    self.overrides["model"] = self.cfg
    self.overrides["task"] = self.task

    # Combine default and model args (model args win) so exporting directly from YAML works
    self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}
    self.model.task = self.task
|
| 205 |
+
|
| 206 |
+
def _load(self, weights: str, task=None):
    """
    Load a model from a weights file and infer the task type from the model head.

    Args:
        weights (str): Model checkpoint or weights file to load.
        task (str | None): Model task; inferred from the weights when None.
    """
    if Path(weights).suffix == ".pt":
        # PyTorch checkpoint: restore the module, its args, and checkpoint metadata
        self.model, self.ckpt = attempt_load_one_weight(weights)
        self.task = self.model.args["task"]
        self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
        self.ckpt_path = self.model.pt_path
    else:
        # Exported format (ONNX, TensorRT, ...): keep the verified file path as the model
        weights = checks.check_file(weights)
        self.model, self.ckpt = weights, None
        self.task = task or guess_model_task(weights)
        self.ckpt_path = weights
    self.overrides["model"] = weights
    self.overrides["task"] = self.task
|
| 227 |
+
|
| 228 |
+
def _check_is_pytorch_model(self):
    """Raise TypeError if the underlying model is not a PyTorch (*.pt) model."""
    is_pt_path = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
    is_module = isinstance(self.model, nn.Module)
    if not (is_module or is_pt_path):
        raise TypeError(
            f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
            f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
            f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
            f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
            f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
        )
|
| 240 |
+
|
| 241 |
+
def reset_weights(self):
    """
    Re-initialize every parameter of the model, discarding all trained weights.

    Iterates all modules and calls 'reset_parameters' where available, then re-enables
    gradients on every parameter so the model can be trained from scratch.

    Returns:
        self (ultralytics.engine.model.Model): This instance, for call chaining.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model;
            the previous docstring incorrectly documented AssertionError).
    """
    self._check_is_pytorch_model()
    for m in self.model.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()
    for p in self.model.parameters():
        p.requires_grad = True
    return self
|
| 262 |
+
|
| 263 |
+
def load(self, weights="yolov8n.pt"):
    """
    Transfer parameters from a weights file or weights object into the model.

    Parameters are matched by name and shape before being copied over. A str/Path
    argument is first resolved via attempt_load_one_weight.

    Args:
        weights (str | Path | object): Path to a weights file, or an already-loaded
            weights object. Defaults to 'yolov8n.pt'.

    Returns:
        self (ultralytics.engine.model.Model): This instance, for call chaining.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model;
            the previous docstring incorrectly documented AssertionError).
    """
    self._check_is_pytorch_model()
    if isinstance(weights, (str, Path)):
        weights, self.ckpt = attempt_load_one_weight(weights)
    self.model.load(weights)
    return self
|
| 284 |
+
|
| 285 |
+
def save(self, filename="model.pt"):
    """
    Save the current model checkpoint (self.ckpt) to a file.

    Args:
        filename (str): Destination file name. Defaults to 'model.pt'.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model;
            the previous docstring incorrectly documented AssertionError).
    """
    self._check_is_pytorch_model()
    # torch is already imported at module level; the previous redundant local import was removed
    torch.save(self.ckpt, filename)
|
| 301 |
+
|
| 302 |
+
def info(self, detailed=False, verbose=True):
    """
    Log or return a summary of the model.

    Args:
        detailed (bool): Include per-layer details when True. Defaults to False.
        verbose (bool): Print the summary when True, otherwise return it. Defaults to True.

    Returns:
        (list): Model information, depending on 'detailed' and 'verbose'.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).
    """
    self._check_is_pytorch_model()
    return self.model.info(detailed=detailed, verbose=verbose)
|
| 321 |
+
|
| 322 |
+
def fuse(self):
    """
    Fuse Conv2d and BatchNorm2d layers in the model to speed up inference.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).
    """
    self._check_is_pytorch_model()
    self.model.fuse()
|
| 333 |
+
|
| 334 |
+
def embed(self, source=None, stream=False, **kwargs):
    """
    Generate image embeddings by calling predict() with the 'embed' option set.

    Args:
        source (str | int | PIL.Image | np.ndarray): Image source for embedding generation.
        stream (bool): Stream predictions when True. Defaults to False.
        **kwargs (dict): Extra options forwarded to predict().

    Returns:
        (List[torch.Tensor]): Image embeddings.
    """
    if not kwargs.get("embed"):
        # Default to the second-to-last layer when no embedding indices were supplied
        kwargs["embed"] = [len(self.model.model) - 2]
    return self.predict(source, stream, **kwargs)
|
| 356 |
+
|
| 357 |
+
def predict(self, source=None, stream=False, predictor=None, **kwargs):
    """
    Perform predictions on the given image source using the YOLO model.

    Builds a predictor on first use and reuses it on later calls (only refreshing its args).
    Detects whether it is being invoked from the `yolo`/`ultralytics` CLI and, if so, enables
    result saving by default and uses the CLI prediction path. Supports SAM-type models via
    a 'prompts' keyword argument.

    Args:
        source (str | int | PIL.Image | np.ndarray, optional): Image source (file path, URL,
            PIL image, numpy array, ...). Falls back to the bundled ASSETS with a warning.
        stream (bool, optional): Treat the source as a continuous stream. Defaults to False.
        predictor (BasePredictor, optional): Custom predictor instance; the default task
            predictor is loaded when None.
        **kwargs (dict): Additional prediction configuration options.

    Returns:
        (List[ultralytics.engine.results.Results]): Prediction results.
    """
    if source is None:
        source = ASSETS
        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")

    # True when launched via the `yolo`/`ultralytics` CLI in predict/track mode
    is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
        x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
    )

    custom = {"conf": 0.25, "save": is_cli, "mode": "predict"}  # method defaults
    args = {**self.overrides, **custom, **kwargs}  # highest priority args on the right
    prompts = args.pop("prompts", None)  # for SAM-type models

    if not self.predictor:
        # First call: build the predictor and bind the model to it
        self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
        self.predictor.setup_model(model=self.model, verbose=is_cli)
    else:  # only update args if predictor is already setup
        self.predictor.args = get_cfg(self.predictor.args, args)
        if "project" in args or "name" in args:
            # Recompute the save directory when the output location changed
            self.predictor.save_dir = get_save_dir(self.predictor.args)
    if prompts and hasattr(self.predictor, "set_prompts"):  # for SAM-type models
        self.predictor.set_prompts(prompts)
    return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
| 408 |
+
|
| 409 |
+
def track(self, source=None, stream=False, persist=False, **kwargs):
    """
    Run object tracking on the given source using the registered trackers.

    Registers trackers on the predictor if not already present, forces a low default
    confidence threshold (ByteTrack-style trackers need low-confidence detections as
    input), and delegates to predict() in 'track' mode.

    Args:
        source (str, optional): Input source (file path, URL, or video stream).
        stream (bool, optional): Treat the source as a continuous video stream. Defaults to False.
        persist (bool, optional): Keep trackers alive between calls. Defaults to False.
        **kwargs (dict): Extra options forwarded to predict().

    Returns:
        (List[ultralytics.engine.results.Results]): Tracking results.
    """
    if not hasattr(self.predictor, "trackers"):
        # Import from the vendored package root for consistency with the top-of-file imports;
        # the bare `ultralytics` root would load a different (duplicate) module instance.
        from yolov8_model.ultralytics.trackers import register_tracker

        register_tracker(self, persist)
    kwargs["conf"] = kwargs.get("conf") or 0.1  # ByteTrack-based method needs low confidence predictions as input
    kwargs["mode"] = "track"
    return self.predict(source=source, stream=stream, **kwargs)
|
| 441 |
+
|
| 442 |
+
def val(self, validator=None, **kwargs):
    """
    Validate the model and record the resulting metrics.

    Merges the model's overrides, method defaults and user kwargs (rightmost wins)
    to configure the run, executes the validator, and stores its metrics on self.

    Args:
        validator (BaseValidator, optional): Custom validator class; the default task
            validator is loaded when None.
        **kwargs (dict): Validation configuration overrides.

    Returns:
        (dict): Validation metrics obtained from the validation process.
    """
    method_defaults = {"rect": True}
    args = {**self.overrides, **method_defaults, **kwargs, "mode": "val"}  # rightmost wins

    validator_cls = validator or self._smart_load("validator")
    runner = validator_cls(args=args, _callbacks=self.callbacks)
    runner(model=self.model)
    self.metrics = runner.metrics
    return runner.metrics
|
| 474 |
+
|
| 475 |
+
def benchmark(self, **kwargs):
    """
    Benchmark the model across various export formats (ONNX, TorchScript, ...).

    Combines DEFAULT_CFG_DICT, the model's own args, method defaults and user kwargs
    (rightmost wins) to configure the run, then delegates to the benchmarks module.

    Args:
        **kwargs (dict): Benchmark configuration overrides (data, imgsz, half, int8,
            device, verbose, ...).

    Returns:
        (dict): Benchmark results.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).
    """
    self._check_is_pytorch_model()
    # Import from the vendored package root for consistency with the top-of-file imports;
    # the bare `ultralytics` root would load a different (duplicate) module instance.
    from yolov8_model.ultralytics.utils.benchmarks import benchmark

    custom = {"verbose": False}  # method defaults
    args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
    return benchmark(
        model=self,
        data=kwargs.get("data"),  # if no 'data' argument passed set data=None for default datasets
        imgsz=args["imgsz"],
        half=args["half"],
        int8=args["int8"],
        device=args["device"],
        verbose=kwargs.get("verbose"),
    )
|
| 512 |
+
|
| 513 |
+
def export(self, **kwargs):
    """
    Export the model to a deployment format (e.g. ONNX, TorchScript).

    Merges the model's overrides, method defaults and user kwargs (rightmost wins)
    and hands them to the Exporter.

    Args:
        **kwargs (dict): Export configuration overrides.

    Returns:
        (object): The exported model in the requested format, or an export-related object.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).
    """
    self._check_is_pytorch_model()
    from .exporter import Exporter

    method_defaults = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False}
    args = {**self.overrides, **method_defaults, **kwargs, "mode": "export"}  # rightmost wins
    exporter = Exporter(overrides=args, _callbacks=self.callbacks)
    return exporter(model=self.model)
|
| 540 |
+
|
| 541 |
+
def train(self, trainer=None, **kwargs):
    """
    Train the model on a dataset with the given configuration.

    When an Ultralytics HUB session with a loaded model exists, HUB training arguments take
    priority over local ones (a warning is logged if both are given). Merges overrides,
    method defaults and user kwargs (rightmost wins), builds the trainer, optionally creates
    a HUB model, runs training, and finally reloads the best/last checkpoint and metrics
    (on rank -1/0 only).

    Args:
        trainer (BaseTrainer, optional): Custom trainer class; the default task trainer is
            loaded when None.
        **kwargs (dict): Training configuration overrides.

    Returns:
        (dict | None): Training metrics if available; otherwise None.

    Raises:
        TypeError: If the model is not a PyTorch model (raised by _check_is_pytorch_model).
    """
    self._check_is_pytorch_model()
    if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model
        if any(kwargs):
            LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
        kwargs = self.session.train_args  # overwrite kwargs

    checks.check_pip_update_available()

    # A 'cfg' YAML replaces the model's own overrides entirely
    overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
    custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]}  # method defaults
    args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
    # if args.get("resume"):
    #     args["resume"] = self.ckpt_path

    self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
    if not args.get("resume"):  # manually set model only if not resuming
        self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
        self.model = self.trainer.model

    if SETTINGS["hub"] is True and not self.session:
        # Create a model in HUB
        try:
            self.session = self._get_hub_session(self.model_name)
            if self.session:
                self.session.create_model(args)
                # Check model was created
                if not getattr(self.session.model, "id", None):
                    self.session = None
        except (PermissionError, ModuleNotFoundError):
            # Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
            pass

    self.trainer.hub_session = self.session  # attach optional HUB session
    self.trainer.train()
    # Update model and cfg after training
    if RANK in (-1, 0):
        ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
        self.model, _ = attempt_load_one_weight(ckpt)
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
    return self.metrics
|
| 610 |
+
|
| 611 |
+
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
    """
    Conducts hyperparameter tuning for the model, with an option to use Ray Tune.

    Two tuning backends are supported: Ray Tune ('run_ray_tune' from ultralytics.utils.tuner)
    when `use_ray` is True, otherwise the internal 'Tuner' class. Default, overridden and
    custom arguments are merged to configure the tuning run.

    Args:
        use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
        iterations (int): The number of tuning iterations to perform. Defaults to 10.
        *args (list): Variable length argument list forwarded to the tuner.
        **kwargs (dict): Arbitrary keyword arguments merged with the model's overrides.

    Returns:
        (dict): A dictionary containing the results of the hyperparameter search.

    Raises:
        AssertionError: If the model is not a PyTorch model.
    """
    self._check_is_pytorch_model()

    if use_ray:
        from ultralytics.utils.tuner import run_ray_tune

        return run_ray_tune(self, max_samples=iterations, *args, **kwargs)

    from .tuner import Tuner

    # Rightmost entries win: explicit kwargs override the stored overrides.
    tuner_args = {**self.overrides, **kwargs, "mode": "train"}
    return Tuner(args=tuner_args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
| 643 |
+
|
| 644 |
+
def _apply(self, fn):
    """Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
    self._check_is_pytorch_model()
    self = super()._apply(fn)  # noqa
    # The predictor caches device state, so drop it after any device/dtype change.
    self.predictor = None
    # Keep overrides in sync, e.g. device(type='cuda', index=0) -> 'cuda:0'.
    self.overrides["device"] = self.device
    return self
|
| 651 |
+
|
| 652 |
+
@property
def names(self):
    """
    Retrieves the class names associated with the loaded model.

    Names are validated with 'check_class_names' from ultralytics.nn.autobackend
    before being returned.

    Returns:
        (list | None): The class names of the model if available, otherwise None.
    """
    from ultralytics.nn.autobackend import check_class_names

    if not hasattr(self.model, "names"):
        return None
    return check_class_names(self.model.names)
|
| 666 |
+
|
| 667 |
+
@property
def device(self):
    """
    Retrieves the device on which the model's parameters are allocated.

    Only meaningful for nn.Module models; any other model type yields None.

    Returns:
        (torch.device | None): The device (CPU/GPU) of the model if it is a PyTorch model, otherwise None.
    """
    if not isinstance(self.model, nn.Module):
        return None
    return next(self.model.parameters()).device
|
| 679 |
+
|
| 680 |
+
@property
def transforms(self):
    """
    Retrieves the transformations applied to the input data of the loaded model.

    Returns:
        (object | None): The transform object of the model if available, otherwise None.
    """
    # getattr with a default is equivalent to the hasattr ternary: both swallow
    # only AttributeError and fall back to None.
    return getattr(self.model, "transforms", None)
|
| 691 |
+
|
| 692 |
+
def add_callback(self, event: str, func):
    """
    Adds a callback function for a specified event.

    Registers a custom callback that is triggered on a specific event during model
    training or inference.

    Args:
        event (str): The name of the event to attach the callback to.
        func (callable): The callback function to be registered.

    Raises:
        ValueError: If the event name is not recognized.
    """
    hooks = self.callbacks[event]
    hooks.append(func)
|
| 707 |
+
|
| 708 |
+
def clear_callback(self, event: str):
    """
    Clears all callback functions registered for a specified event.

    Removes every custom and default callback associated with the given event.

    Args:
        event (str): The name of the event for which to clear the callbacks.

    Raises:
        ValueError: If the event name is not recognized.
    """
    # Rebind to a fresh list so default hooks are dropped as well.
    self.callbacks[event] = []
|
| 721 |
+
|
| 722 |
+
def reset_callbacks(self):
    """
    Resets all callbacks to their default functions.

    Reinstates the first default callback for every known event, discarding any
    custom callbacks added previously.
    """
    for event, defaults in callbacks.default_callbacks.items():
        self.callbacks[event] = [defaults[0]]
|
| 731 |
+
|
| 732 |
+
@staticmethod
def _reset_ckpt_args(args):
    """Reset arguments when loading a PyTorch model."""
    # Only these keys survive a checkpoint load; everything else is dropped.
    keep = {"imgsz", "data", "task", "single_cls"}
    return {key: val for key, val in args.items() if key in keep}
|
| 737 |
+
|
| 738 |
+
# def __getattr__(self, attr):
|
| 739 |
+
# """Raises error if object has no requested attribute."""
|
| 740 |
+
# name = self.__class__.__name__
|
| 741 |
+
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
| 742 |
+
|
| 743 |
+
def _smart_load(self, key):
    """Load model/trainer/validator/predictor."""
    try:
        return self.task_map[self.task][key]
    except Exception as e:
        name = self.__class__.__name__
        mode = inspect.stack()[1][3]  # name of the calling function
        msg = f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet."
        raise NotImplementedError(emojis(msg)) from e
|
| 753 |
+
|
| 754 |
+
@property
def task_map(self):
    """
    Map head to model, trainer, validator, and predictor classes.

    Subclasses must override this mapping; the base class defines no tasks.

    Returns:
        task_map (dict): The map of model task to mode classes.
    """
    raise NotImplementedError("Please provide task map for your model!")
|
| 763 |
+
|
| 764 |
+
def profile(self, imgsz):
    """
    Run a profiling forward pass through the model with a random input batch.

    Builds a random (2, 3, H, W) tensor from `imgsz`, moves it to CUDA when the
    model lives there, and calls `self.model.predict(..., profile=True)`.

    Args:
        imgsz (int | tuple | list): Square size, or (height, width) pair.

    Returns:
        The value returned by `self.model.predict` for the profiled pass.
    """
    # isinstance is the idiomatic type test (also accepts int subclasses).
    if isinstance(imgsz, int):
        inputs = torch.randn((2, 3, imgsz, imgsz))
    else:
        inputs = torch.randn((2, 3, imgsz[0], imgsz[1]))
    if next(self.model.parameters()).device.type == "cuda":
        inputs = inputs.to(torch.device("cuda"))
    # BUG FIX: the original only returned predictions on CUDA; the CPU branch
    # called predict() and silently discarded the result, returning None.
    return self.model.predict(inputs, profile=True)
|
yolov8_model/ultralytics/engine/predictor.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
"""
|
| 3 |
+
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
|
| 4 |
+
|
| 5 |
+
Usage - sources:
|
| 6 |
+
$ yolo mode=predict model=yolov8n.pt source=0 # webcam
|
| 7 |
+
img.jpg # image
|
| 8 |
+
vid.mp4 # video
|
| 9 |
+
screen # screenshot
|
| 10 |
+
path/ # directory
|
| 11 |
+
list.txt # list of images
|
| 12 |
+
list.streams # list of streams
|
| 13 |
+
'path/*.jpg' # glob
|
| 14 |
+
'https://youtu.be/LNwODJXcvt4' # YouTube
|
| 15 |
+
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream
|
| 16 |
+
|
| 17 |
+
Usage - formats:
|
| 18 |
+
$ yolo mode=predict model=yolov8n.pt # PyTorch
|
| 19 |
+
yolov8n.torchscript # TorchScript
|
| 20 |
+
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
| 21 |
+
yolov8n_openvino_model # OpenVINO
|
| 22 |
+
yolov8n.engine # TensorRT
|
| 23 |
+
yolov8n.mlpackage # CoreML (macOS-only)
|
| 24 |
+
yolov8n_saved_model # TensorFlow SavedModel
|
| 25 |
+
yolov8n.pb # TensorFlow GraphDef
|
| 26 |
+
yolov8n.tflite # TensorFlow Lite
|
| 27 |
+
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
| 28 |
+
yolov8n_paddle_model # PaddlePaddle
|
| 29 |
+
"""
|
| 30 |
+
import platform
|
| 31 |
+
import threading
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
|
| 34 |
+
import cv2
|
| 35 |
+
import numpy as np
|
| 36 |
+
import torch
|
| 37 |
+
from PIL import Image
|
| 38 |
+
from yolov8_model.ultralytics.cfg import get_cfg, get_save_dir
|
| 39 |
+
from yolov8_model.ultralytics.data import load_inference_source
|
| 40 |
+
from yolov8_model.ultralytics.data.augment import LetterBox, classify_transforms
|
| 41 |
+
from yolov8_model.ultralytics.nn.autobackend import AutoBackend
|
| 42 |
+
from yolov8_model.ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
|
| 43 |
+
from yolov8_model.ultralytics.utils.checks import check_imgsz, check_imshow
|
| 44 |
+
from yolov8_model.ultralytics.utils.files import increment_path
|
| 45 |
+
from yolov8_model.ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
| 46 |
+
|
| 47 |
+
STREAM_WARNING = """
|
| 48 |
+
WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
|
| 49 |
+
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
|
| 50 |
+
|
| 51 |
+
Example:
|
| 52 |
+
results = model(source=..., stream=True) # generator of Results objects
|
| 53 |
+
for r in results:
|
| 54 |
+
boxes = r.boxes # Boxes object for bbox outputs
|
| 55 |
+
masks = r.masks # Masks object for segment masks outputs
|
| 56 |
+
probs = r.probs # Class probabilities for classification outputs
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class BasePredictor:
    """
    BasePredictor.

    A base class for creating predictors.

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether the predictor has finished setup.
        model (nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_path (str): Path to video file.
        vid_writer (cv2.VideoWriter): Video writer for saving video output.
        data_path (str): Path to data.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BasePredictor class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = get_save_dir(self.args)
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        # Per-stream video state; sized to dataset.bs in setup_source().
        self.vid_path, self.vid_writer, self.vid_frame = None, None, None
        self.plotted_img = None
        self.data_path = None
        self.source_type = None
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        self._lock = threading.Lock()  # for automatic thread-safe inference
        callbacks.add_integration_callbacks(self)

    def preprocess(self, im):
        """
        Prepares input image before inference.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        if not_tensor:
            im /= 255  # 0 - 255 to 0.0 - 1.0
        return im

    def inference(self, im, *args, **kwargs):
        """Runs inference on a given image using the specified model and arguments."""
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and (not self.source_type.tensor)
            else False
        )
        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

    def pre_transform(self, im):
        """
        Pre-transform input image before inference.

        Args:
            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (list): A list of transformed images.
        """
        # Rectangular (auto) letterbox is only valid when all images share a shape
        # and the backend is a native PyTorch model.
        same_shapes = all(x.shape == im[0].shape for x in im)
        letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
        return [letterbox(image=x) for x in im]

    def write_results(self, idx, results, batch):
        """Write inference results to a file or directory."""
        p, im, _ = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            log_string += f"{idx}: "
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, "frame", 0)
        self.data_path = p
        self.txt_path = str(self.save_dir / "labels" / p.stem) + ("" if self.dataset.mode == "image" else f"_{frame}")
        log_string += "%gx%g " % im.shape[2:]  # print string
        result = results[idx]
        log_string += result.verbose()

        if self.args.save or self.args.show:  # Add bbox to image
            plot_args = {
                "line_width": self.args.line_width,
                "boxes": self.args.show_boxes,
                "conf": self.args.show_conf,
                "labels": self.args.show_labels,
            }
            if not self.args.retina_masks:
                plot_args["im_gpu"] = im[idx]
            self.plotted_img = result.plot(**plot_args)
        # Write
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(
                save_dir=self.save_dir / "crops",
                file_name=self.data_path.stem + ("" if self.dataset.mode == "image" else f"_{frame}"),
            )

        return log_string

    def postprocess(self, preds, img, orig_imgs):
        """Post-processes predictions for an image and returns them."""
        # Identity by default; task-specific subclasses are expected to override.
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """Performs inference on an image or stream."""
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one

    def predict_cli(self, source=None, model=None):
        """
        Method used for CLI prediction.

        It uses always generator as outputs as not required by CLI mode.
        """
        gen = self.stream_inference(source, model)
        for _ in gen:  # noqa, running CLI inference without accumulating any outputs (do not modify)
            pass

    def setup_source(self, source):
        """Sets up source and inference mode."""
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        self.transforms = (
            getattr(
                self.model.model,
                "transforms",
                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
            )
            if self.args.task == "classify"
            else None
        )
        self.dataset = load_inference_source(
            source=source, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
        )
        self.source_type = self.dataset.source_type
        # Warn when results would accumulate in RAM without stream=True.
        if not getattr(self, "stream", True) and (
            self.dataset.mode == "stream"  # streams
            or len(self.dataset) > 1000  # images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)
        self.vid_path = [None] * self.dataset.bs
        self.vid_writer = [None] * self.dataset.bs
        self.vid_frame = [None] * self.dataset.bs

    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """Streams real-time inference on camera feed and saves results to file."""
        if self.args.verbose:
            LOGGER.info("")

        # Setup model
        if not self.model:
            self.setup_model(model)

        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)

            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            all_results = []

            for batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                self.batch = batch
                path, im0s, vid_cap, s = batch

                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)

                self.run_callbacks("on_predict_postprocess_end")
                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
                    p = Path(p)

                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s += self.write_results(i, self.results, (p, im, im0))
                    if self.args.save or self.args.save_txt:
                        self.results[i].save_dir = self.save_dir.__str__()
                    if self.args.show and self.plotted_img is not None:
                        self.show(p)
                    if self.args.save and self.plotted_img is not None:
                        self.save_preds(vid_cap, i, str(self.save_dir / p.name))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results
                all_results.extend(self.results)
                # Print time (inference-only)
                if self.args.verbose:
                    LOGGER.info(f"{s}{profilers[1].dt * 1E3:.1f}ms")

        # Release assets
        if isinstance(self.vid_writer[-1], cv2.VideoWriter):
            self.vid_writer[-1].release()  # release final video writer

        # Print results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(1, 3, *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

        self.run_callbacks("on_predict_end")
        # NOTE(review): returning a value from a generator only sets
        # StopIteration.value; ordinary `for` consumers never see all_results.
        # Presumably intentional for a caller using gen.send/StopIteration — confirm.
        return all_results
    def setup_model(self, model, verbose=True):
        """Initialize YOLO model with given parameters and set it to evaluation mode."""
        self.model = AutoBackend(
            model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            fuse=True,
            verbose=verbose,
        )

        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()

    def show(self, p):
        """Display an image in a window using OpenCV imshow()."""
        im0 = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
        cv2.imshow(str(p), im0)
        # self.batch[3] is the log string; images pause 500 ms, streams 1 ms.
        cv2.waitKey(500 if self.batch[3].startswith("image") else 1)  # 1 millisecond

    def save_preds(self, vid_cap, idx, save_path):
        """Save video predictions as mp4 at specified path."""
        im0 = self.plotted_img
        # Save imgs
        if self.dataset.mode == "image":
            cv2.imwrite(save_path, im0)
        else:  # 'video' or 'stream'
            frames_path = f'{save_path.split(".", 1)[0]}_frames/'
            if self.vid_path[idx] != save_path:  # new video
                self.vid_path[idx] = save_path
                if self.args.save_frames:
                    Path(frames_path).mkdir(parents=True, exist_ok=True)
                    self.vid_frame[idx] = 0
                if isinstance(self.vid_writer[idx], cv2.VideoWriter):
                    self.vid_writer[idx].release()  # release previous video writer
                if vid_cap:  # video
                    fps = int(vid_cap.get(cv2.CAP_PROP_FPS))  # integer required, floats produce error in MP4 codec
                    w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                else:  # stream
                    fps, w, h = 30, im0.shape[1], im0.shape[0]
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                self.vid_writer[idx] = cv2.VideoWriter(
                    str(Path(save_path).with_suffix(suffix)), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)
                )
            # Write video
            self.vid_writer[idx].write(im0)

            # Write frame
            if self.args.save_frames:
                cv2.imwrite(f"{frames_path}{self.vid_frame[idx]}.jpg", im0)
                self.vid_frame[idx] += 1

    def run_callbacks(self, event: str):
        """Runs all registered callbacks for a specific event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """Add callback."""
        self.callbacks[event].append(func)
|
yolov8_model/ultralytics/engine/results.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
"""
|
| 3 |
+
Ultralytics Results, Boxes and Masks classes for handling inference results.
|
| 4 |
+
|
| 5 |
+
Usage: See https://docs.ultralytics.com/modes/predict/
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from yolov8_model.ultralytics.data.augment import LetterBox
|
| 16 |
+
from yolov8_model.ultralytics.utils import LOGGER, SimpleClass, ops
|
| 17 |
+
from yolov8_model.ultralytics.utils.plotting import Annotator, colors, save_one_box
|
| 18 |
+
from yolov8_model.ultralytics.utils.torch_utils import smart_inference_mode
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BaseTensor(SimpleClass):
    """Base class wrapping a tensor or ndarray with device-conversion and indexing helpers."""

    def __init__(self, data, orig_shape) -> None:
        """
        Store prediction data together with the original image shape.

        Args:
            data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
            orig_shape (tuple): Original shape of image.
        """
        assert isinstance(data, (torch.Tensor, np.ndarray))
        self.data = data
        self.orig_shape = orig_shape

    @property
    def shape(self):
        """Shape of the underlying data tensor."""
        return self.data.shape

    def cpu(self):
        """Return a copy of the tensor on CPU memory (no-op when data is already a numpy array)."""
        if isinstance(self.data, np.ndarray):
            return self
        return self.__class__(self.data.cpu(), self.orig_shape)

    def numpy(self):
        """Return a copy of the tensor as a numpy array (no-op when data is already a numpy array)."""
        if isinstance(self.data, np.ndarray):
            return self
        return self.__class__(self.data.numpy(), self.orig_shape)

    def cuda(self):
        """Return a copy of the tensor on GPU memory."""
        return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)

    def to(self, *args, **kwargs):
        """Return a copy of the tensor with the specified device and dtype."""
        return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)

    def __len__(self):  # override len(results)
        """Number of entries along the first dimension of the data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return a new instance of the same class holding the indexed slice of the data."""
        return self.__class__(self.data[idx], self.orig_shape)
| 65 |
+
|
| 66 |
+
class Results(SimpleClass):
    """
    A class for storing and manipulating inference results.

    Args:
        orig_img (numpy.ndarray): The original image as a numpy array.
        path (str): The path to the image file.
        names (dict): A dictionary of class names.
        boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
        masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
        probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
        keypoints (List[List[float]], optional): A list of detected keypoints for each object.

    Attributes:
        orig_img (numpy.ndarray): The original image as a numpy array.
        orig_shape (tuple): The original image shape in (height, width) format.
        boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
        masks (Masks, optional): A Masks object containing the detection masks.
        probs (Probs, optional): A Probs object containing probabilities of each class for classification task.
        keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object.
        speed (dict): A dictionary of preprocess, inference, and postprocess speeds in milliseconds per image.
        names (dict): A dictionary of class names.
        path (str): The path to the image file.
        _keys (tuple): A tuple of attribute names for non-empty attributes.
    """

    def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None) -> None:
        """Initialize the Results class, wrapping each raw prediction tensor in its typed container."""
        self.orig_img = orig_img
        # assumes orig_img is HWC (or a tensor whose first two dims are H, W) — TODO confirm at call sites
        self.orig_shape = orig_img.shape[:2]
        self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None  # native size boxes
        self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks
        self.probs = Probs(probs) if probs is not None else None
        self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
        self.obb = OBB(obb, self.orig_shape) if obb is not None else None
        self.speed = {"preprocess": None, "inference": None, "postprocess": None}  # milliseconds per image
        self.names = names
        self.path = path
        self.save_dir = None
        # Attribute names iterated by __len__/_apply; order determines which attribute defines len().
        self._keys = "boxes", "masks", "probs", "keypoints", "obb"

    def __getitem__(self, idx):
        """Return a Results object for the specified index."""
        return self._apply("__getitem__", idx)

    def __len__(self):
        """Return the number of detections in the Results object.

        NOTE(review): implicitly returns None when every key attribute is None — callers
        appear to rely on at least one attribute being set.
        """
        for k in self._keys:
            v = getattr(self, k)
            if v is not None:
                return len(v)

    def update(self, boxes=None, masks=None, probs=None, obb=None):
        """Update the boxes, masks, and probs attributes of the Results object."""
        if boxes is not None:
            # Boxes are clipped to the original image bounds before storing.
            self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
        if masks is not None:
            self.masks = Masks(masks, self.orig_shape)
        if probs is not None:
            # NOTE(review): probs is stored raw here, unlike __init__ which wraps it in Probs.
            self.probs = probs
        if obb is not None:
            self.obb = OBB(obb, self.orig_shape)

    def _apply(self, fn, *args, **kwargs):
        """
        Applies a function to all non-empty attributes and returns a new Results object with modified attributes. This
        function is internally called by methods like .to(), .cuda(), .cpu(), etc.

        Args:
            fn (str): The name of the function to apply.
            *args: Variable length argument list to pass to the function.
            **kwargs: Arbitrary keyword arguments to pass to the function.

        Returns:
            Results: A new Results object with attributes modified by the applied function.
        """
        r = self.new()
        for k in self._keys:
            v = getattr(self, k)
            if v is not None:
                # Dispatch by method name on each container (Boxes, Masks, ...).
                setattr(r, k, getattr(v, fn)(*args, **kwargs))
        return r

    def cpu(self):
        """Return a copy of the Results object with all tensors on CPU memory."""
        return self._apply("cpu")

    def numpy(self):
        """Return a copy of the Results object with all tensors as numpy arrays."""
        return self._apply("numpy")

    def cuda(self):
        """Return a copy of the Results object with all tensors on GPU memory."""
        return self._apply("cuda")

    def to(self, *args, **kwargs):
        """Return a copy of the Results object with tensors on the specified device and dtype."""
        return self._apply("to", *args, **kwargs)

    def new(self):
        """Return a new Results object with the same image, path, and names (no detections attached)."""
        return Results(orig_img=self.orig_img, path=self.path, names=self.names)

    def plot(
        self,
        conf=True,
        line_width=None,
        font_size=None,
        font="Arial.ttf",
        pil=False,
        img=None,
        im_gpu=None,
        kpt_radius=5,
        kpt_line=True,
        labels=True,
        boxes=True,
        masks=True,
        probs=True,
    ):
        """
        Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.

        Args:
            conf (bool): Whether to plot the detection confidence score.
            line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
            font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
            font (str): The font to use for the text.
            pil (bool): Whether to return the image as a PIL Image.
            img (numpy.ndarray): Plot to another image. if not, plot to original image.
            im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
            kpt_radius (int, optional): Radius of the drawn keypoints. Default is 5.
            kpt_line (bool): Whether to draw lines connecting keypoints.
            labels (bool): Whether to plot the label of bounding boxes.
            boxes (bool): Whether to plot the bounding boxes.
            masks (bool): Whether to plot the masks.
            probs (bool): Whether to plot classification probability

        Returns:
            (numpy.ndarray): A numpy array of the annotated image.

        Example:
            ```python
            from PIL import Image
            from ultralytics import YOLO

            model = YOLO('yolov8n.pt')
            results = model('bus.jpg')  # results list
            for r in results:
                im_array = r.plot()  # plot a BGR numpy array of predictions
                im = Image.fromarray(im_array[..., ::-1])  # RGB PIL image
                im.show()  # show image
                im.save('results.jpg')  # save image
            ```
        """
        # Convert a batched CHW tensor image to an HWC uint8 numpy array for drawing.
        if img is None and isinstance(self.orig_img, torch.Tensor):
            img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()

        names = self.names
        is_obb = self.obb is not None
        # Tuple assignment: pred_boxes gets (obb or boxes); show_boxes gets the `boxes` flag.
        pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes
        pred_masks, show_masks = self.masks, masks
        pred_probs, show_probs = self.probs, probs
        annotator = Annotator(
            deepcopy(self.orig_img if img is None else img),
            line_width,
            font_size,
            font,
            pil or (pred_probs is not None and show_probs),  # Classify tasks default to pil=True
            example=names,
        )

        # Plot Segment results
        if pred_masks and show_masks:  # truthiness: None or zero masks both skip the branch
            if im_gpu is None:
                img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
                im_gpu = (
                    torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device)
                    .permute(2, 0, 1)
                    .flip(0)
                    .contiguous()
                    / 255
                )
            idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
            annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)

        # Plot Detect results
        if pred_boxes is not None and show_boxes:
            for d in reversed(pred_boxes):
                # NOTE(review): `conf` is rebound here from the bool flag to the float score (or None).
                c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
                name = ("" if id is None else f"id:{id} ") + names[c]
                label = (f"{name} {conf:.2f}" if conf else name) if labels else None
                box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze()
                annotator.box_label(box, label, color=colors(c, True), rotated=is_obb)

        # Plot Classify results
        if pred_probs is not None and show_probs:
            text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5)
            x = round(self.orig_shape[0] * 0.03)
            annotator.text([x, x], text, txt_color=(255, 255, 255))  # TODO: allow setting colors

        # Plot Pose results
        if self.keypoints is not None:
            for k in reversed(self.keypoints.data):
                annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line)

        return annotator.result()

    def verbose(self):
        """Return log string for each task (class counts for detect, top-5 scores for classify)."""
        log_string = ""
        probs = self.probs
        boxes = self.boxes
        if len(self) == 0:
            return log_string if probs is not None else f"{log_string}(no detections), "
        if probs is not None:
            log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
        if boxes:
            for c in boxes.cls.unique():
                n = (boxes.cls == c).sum()  # detections per class
                log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "
        return log_string

    def save_txt(self, txt_file, save_conf=False):
        """
        Save predictions into txt file.

        Args:
            txt_file (str): txt file path.
            save_conf (bool): save confidence score or not.
        """
        is_obb = self.obb is not None
        boxes = self.obb if is_obb else self.boxes
        masks = self.masks
        probs = self.probs
        kpts = self.keypoints
        texts = []
        if probs is not None:
            # Classify
            [texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5]
        elif boxes:
            # Detect/segment/pose
            for j, d in enumerate(boxes):
                c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
                line = (c, *(d.xyxyxyxyn.view(-1) if is_obb else d.xywhn.view(-1)))
                if masks:
                    # Segment: the polygon replaces the box coordinates on the line.
                    seg = masks[j].xyn[0].copy().reshape(-1)  # reversed mask.xyn, (n,2) to (n*2)
                    line = (c, *seg)
                if kpts is not None:
                    kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
                    line += (*kpt.reshape(-1).tolist(),)
                # Bool multiplication appends conf only when save_conf is True.
                line += (conf,) * save_conf + (() if id is None else (id,))
                texts.append(("%g " * len(line)).rstrip() % line)

        if texts:
            Path(txt_file).parent.mkdir(parents=True, exist_ok=True)  # make directory
            with open(txt_file, "a") as f:
                f.writelines(text + "\n" for text in texts)

    def save_crop(self, save_dir, file_name=Path("im.jpg")):
        """
        Save cropped predictions to `save_dir/cls/file_name.jpg`.

        Args:
            save_dir (str | pathlib.Path): Save path.
            file_name (str | pathlib.Path): File name.
        """
        if self.probs is not None:
            LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.")
            return
        if self.obb is not None:
            LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.")
            return
        for d in self.boxes:
            save_one_box(
                d.xyxy,
                self.orig_img.copy(),
                file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg",
                BGR=True,
            )

    def tojson(self, normalize=False):
        """Convert the object to JSON format (detections only; classify is unsupported)."""
        if self.probs is not None:
            LOGGER.warning("Warning: Classify task do not support `tojson` yet.")
            return

        import json

        # Create list of detection dictionaries
        results = []
        data = self.boxes.data.cpu().tolist()
        # Dividing by (1, 1) leaves pixel coordinates untouched when normalize is False.
        h, w = self.orig_shape if normalize else (1, 1)
        for i, row in enumerate(data):  # xyxy, track_id if tracking, conf, class_id
            box = {"x1": row[0] / w, "y1": row[1] / h, "x2": row[2] / w, "y2": row[3] / h}
            conf = row[-2]
            class_id = int(row[-1])
            name = self.names[class_id]
            result = {"name": name, "class": class_id, "confidence": conf, "box": box}
            if self.boxes.is_track:
                result["track_id"] = int(row[-3])  # track ID
            if self.masks:
                x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1]  # numpy array
                result["segments"] = {"x": (x / w).tolist(), "y": (y / h).tolist()}
            if self.keypoints is not None:
                x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1)  # torch Tensor
                result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
            results.append(result)

        # Convert detections to JSON
        return json.dumps(results, indent=2)
| 377 |
+
|
| 378 |
+
class Boxes(BaseTensor):
    """
    Container for axis-aligned detection boxes with convenient format conversions.

    Args:
        boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
            with shape (num_boxes, 6) or (num_boxes, 7). The last two columns contain confidence and class values.
            If present, the third last column contains track IDs.
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format.
        conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
        cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
        id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
        xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format.
        xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size.
        xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size.
        data (torch.Tensor): The raw bboxes tensor (alias for `boxes`).

    Methods:
        cpu(): Move the object to CPU memory.
        numpy(): Convert the object to a numpy array.
        cuda(): Move the object to CUDA memory.
        to(*args, **kwargs): Move the object to the specified device.
    """

    def __init__(self, boxes, orig_shape) -> None:
        """Initialize with a (N, 6|7) detection tensor and the original (height, width) shape."""
        if boxes.ndim == 1:
            boxes = boxes[None, :]  # promote a single box to a batch of one
        n = boxes.shape[-1]
        assert n in (6, 7), f"expected 6 or 7 values but got {n}"  # xyxy, track_id, conf, cls
        super().__init__(boxes, orig_shape)
        self.is_track = n == 7  # a 7th column means track IDs are present
        self.orig_shape = orig_shape

    @property
    def xyxy(self):
        """Boxes in (x1, y1, x2, y2) pixel format, first four columns of the data."""
        return self.data[:, :4]

    @property
    def conf(self):
        """Per-box confidence scores (second-to-last column)."""
        return self.data[:, -2]

    @property
    def cls(self):
        """Per-box class indices (last column)."""
        return self.data[:, -1]

    @property
    def id(self):
        """Per-box track IDs when tracking, otherwise None."""
        return self.data[:, -3] if self.is_track else None

    @property
    @lru_cache(maxsize=2)  # maxsize 1 should suffice
    def xywh(self):
        """Boxes converted to (x_center, y_center, width, height) format."""
        return ops.xyxy2xywh(self.xyxy)

    @property
    @lru_cache(maxsize=2)
    def xyxyn(self):
        """Boxes in xyxy format normalized to [0, 1] by the original image size."""
        if isinstance(self.xyxy, torch.Tensor):
            normalized = self.xyxy.clone()
        else:
            normalized = np.copy(self.xyxy)
        height, width = self.orig_shape
        normalized[..., [0, 2]] /= width
        normalized[..., [1, 3]] /= height
        return normalized

    @property
    @lru_cache(maxsize=2)
    def xywhn(self):
        """Boxes in xywh format normalized to [0, 1] by the original image size."""
        # Recompute rather than reuse self.xywh so the cached tensor is never mutated.
        normalized = ops.xyxy2xywh(self.xyxy)
        height, width = self.orig_shape
        normalized[..., [0, 2]] /= width
        normalized[..., [1, 3]] /= height
        return normalized
| 459 |
+
|
| 460 |
+
class Masks(BaseTensor):
    """
    Container for segmentation masks with contour extraction helpers.

    Attributes:
        xy (list): A list of segments in pixel coordinates.
        xyn (list): A list of normalized segments.

    Methods:
        cpu(): Returns the masks tensor on CPU memory.
        numpy(): Returns the masks tensor as a numpy array.
        cuda(): Returns the masks tensor on GPU memory.
        to(device, dtype): Returns the masks tensor with the specified device and dtype.
    """

    def __init__(self, masks, orig_shape) -> None:
        """Initialize with a (N, H, W) mask tensor and the original image shape."""
        if masks.ndim == 2:
            masks = masks[None, :]  # promote a single mask to a batch of one
        super().__init__(masks, orig_shape)

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """Mask contours as (n, 2) segments normalized to the original image size."""
        segments = ops.masks2segments(self.data)
        mask_hw = self.data.shape[1:]
        return [ops.scale_coords(mask_hw, seg, self.orig_shape, normalize=True) for seg in segments]

    @property
    @lru_cache(maxsize=1)
    def xy(self):
        """Mask contours as (n, 2) segments in pixel coordinates of the original image."""
        segments = ops.masks2segments(self.data)
        mask_hw = self.data.shape[1:]
        return [ops.scale_coords(mask_hw, seg, self.orig_shape, normalize=False) for seg in segments]
| 499 |
+
|
| 500 |
+
class Keypoints(BaseTensor):
    """
    Container for pose keypoints with normalized/pixel accessors.

    Attributes:
        xy (torch.Tensor): A collection of keypoints containing x, y coordinates for each detection.
        xyn (torch.Tensor): A normalized version of xy with coordinates in the range [0, 1].
        conf (torch.Tensor): Confidence values associated with keypoints if available, otherwise None.

    Methods:
        cpu(): Returns a copy of the keypoints tensor on CPU memory.
        numpy(): Returns a copy of the keypoints tensor as a numpy array.
        cuda(): Returns a copy of the keypoints tensor on GPU memory.
        to(device, dtype): Returns a copy of the keypoints tensor with the specified device and dtype.
    """

    @smart_inference_mode()  # avoid keypoints < conf in-place error
    def __init__(self, keypoints, orig_shape) -> None:
        """Initialize with a (N, K, 2|3) keypoint tensor and the original image shape."""
        if keypoints.ndim == 2:
            keypoints = keypoints[None, :]  # promote a single detection to a batch of one
        if keypoints.shape[2] == 3:  # x, y, conf
            # Zero out coordinates of low-confidence (non-visible) points, in place.
            invisible = keypoints[..., 2] < 0.5  # points with conf < 0.5 (not visible)
            keypoints[..., :2][invisible] = 0
        super().__init__(keypoints, orig_shape)
        self.has_visible = self.data.shape[-1] == 3

    @property
    @lru_cache(maxsize=1)
    def xy(self):
        """Keypoint (x, y) pixel coordinates, shape (N, K, 2)."""
        return self.data[..., :2]

    @property
    @lru_cache(maxsize=1)
    def xyn(self):
        """Keypoint coordinates normalized to [0, 1] by the original image size."""
        if isinstance(self.xy, torch.Tensor):
            normalized = self.xy.clone()
        else:
            normalized = np.copy(self.xy)
        normalized[..., 0] /= self.orig_shape[1]
        normalized[..., 1] /= self.orig_shape[0]
        return normalized

    @property
    @lru_cache(maxsize=1)
    def conf(self):
        """Per-keypoint confidence values when present, otherwise None."""
        return self.data[..., 2] if self.has_visible else None
| 548 |
+
|
| 549 |
+
class Probs(BaseTensor):
    """
    Container for classification probabilities with top-k accessors.

    Attributes:
        top1 (int): Index of the top 1 class.
        top5 (list[int]): Indices of the top 5 classes.
        top1conf (torch.Tensor): Confidence of the top 1 class.
        top5conf (torch.Tensor): Confidences of the top 5 classes.

    Methods:
        cpu(): Returns a copy of the probs tensor on CPU memory.
        numpy(): Returns a copy of the probs tensor as a numpy array.
        cuda(): Returns a copy of the probs tensor on GPU memory.
        to(): Returns a copy of the probs tensor with the specified device and dtype.
    """

    def __init__(self, probs, orig_shape=None) -> None:
        """Initialize with a 1D probability vector; orig_shape is unused but kept for interface parity."""
        super().__init__(probs, orig_shape)

    @property
    @lru_cache(maxsize=1)
    def top1(self):
        """Index of the highest-probability class."""
        return int(self.data.argmax())

    @property
    @lru_cache(maxsize=1)
    def top5(self):
        """Indices of the five highest-probability classes, descending."""
        # Negating before argsort yields a descending order that works for both torch and numpy.
        return (-self.data).argsort(0)[:5].tolist()

    @property
    @lru_cache(maxsize=1)
    def top1conf(self):
        """Probability of the top-1 class."""
        return self.data[self.top1]

    @property
    @lru_cache(maxsize=1)
    def top5conf(self):
        """Probabilities of the top-5 classes."""
        return self.data[self.top5]
| 594 |
+
|
| 595 |
+
class OBB(BaseTensor):
    """
    A class for storing and manipulating Oriented Bounding Boxes (OBB).

    Args:
        boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
            with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values.
            If present, the third last column contains track IDs, and the fifth column from the left contains rotation.
        orig_shape (tuple): Original image size, in the format (height, width).

    Attributes:
        xywhr (torch.Tensor | numpy.ndarray): The boxes in [x_center, y_center, width, height, rotation] format.
        conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
        cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
        id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
        xyxyxyxyn (torch.Tensor | numpy.ndarray): The rotated boxes in xyxyxyxy format normalized by original image size.
        xyxyxyxy (torch.Tensor | numpy.ndarray): The rotated boxes in xyxyxyxy format.
        xyxy (torch.Tensor | numpy.ndarray): The horizontal boxes in xyxyxyxy format.
        data (torch.Tensor): The raw OBB tensor (alias for `boxes`).

    Methods:
        cpu(): Move the object to CPU memory.
        numpy(): Convert the object to a numpy array.
        cuda(): Move the object to CUDA memory.
        to(*args, **kwargs): Move the object to the specified device.
    """

    def __init__(self, boxes, orig_shape) -> None:
        """Initialize an OBB instance with oriented box data and the original (height, width) shape."""
        if boxes.ndim == 1:
            boxes = boxes[None, :]  # promote a single box to a batch of one
        n = boxes.shape[-1]
        assert n in (7, 8), f"expected 7 or 8 values but got {n}"  # xywh, rotation, track_id, conf, cls
        super().__init__(boxes, orig_shape)
        self.is_track = n == 8  # an 8th column means track IDs are present
        self.orig_shape = orig_shape

    @property
    def xywhr(self):
        """Return the rotated boxes in [x_center, y_center, width, height, rotation] format, (N, 5)."""
        return self.data[:, :5]

    @property
    def conf(self):
        """Per-box confidence scores (second-to-last column)."""
        return self.data[:, -2]

    @property
    def cls(self):
        """Per-box class indices (last column)."""
        return self.data[:, -1]

    @property
    def id(self):
        """Per-box track IDs when tracking, otherwise None."""
        return self.data[:, -3] if self.is_track else None

    @property
    @lru_cache(maxsize=2)
    def xyxyxyxy(self):
        """Return the four rotated corner points per box in pixel coordinates, (N, 4, 2)."""
        return ops.xywhr2xyxyxyxy(self.xywhr)

    @property
    @lru_cache(maxsize=2)
    def xyxyxyxyn(self):
        """Return the rotated corner points normalized to [0, 1] by original image size, (N, 4, 2)."""
        xyxyxyxyn = self.xyxyxyxy.clone() if isinstance(self.xyxyxyxy, torch.Tensor) else np.copy(self.xyxyxyxy)
        xyxyxyxyn[..., 0] /= self.orig_shape[1]  # x / width
        # BUGFIX: y must be normalized by image height (orig_shape[0]); previously divided by width.
        xyxyxyxyn[..., 1] /= self.orig_shape[0]  # y / height
        return xyxyxyxyn

    @property
    @lru_cache(maxsize=2)
    def xyxy(self):
        """
        Return the horizontal boxes in xyxy format, (N, 4).

        Accepts both torch and numpy boxes: the axis-aligned box is the min/max envelope
        of the four rotated corners.
        """
        x1 = self.xyxyxyxy[..., 0].min(1).values
        x2 = self.xyxyxyxy[..., 0].max(1).values
        y1 = self.xyxyxyxy[..., 1].min(1).values
        y2 = self.xyxyxyxy[..., 1].max(1).values
        xyxy = [x1, y1, x2, y2]
        return np.stack(xyxy, axis=-1) if isinstance(self.data, np.ndarray) else torch.stack(xyxy, dim=-1)
|
yolov8_model/ultralytics/engine/trainer.py
ADDED
|
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
"""
|
| 3 |
+
Train a model on a dataset.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import os
|
| 11 |
+
import subprocess
|
| 12 |
+
import time
|
| 13 |
+
import warnings
|
| 14 |
+
from copy import deepcopy
|
| 15 |
+
from datetime import datetime, timedelta
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from torch import distributed as dist
|
| 21 |
+
from torch import nn, optim
|
| 22 |
+
|
| 23 |
+
from yolov8_model.ultralytics.cfg import get_cfg, get_save_dir
|
| 24 |
+
from yolov8_model.ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
| 25 |
+
from yolov8_model.ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
| 26 |
+
from yolov8_model.ultralytics.utils import (
|
| 27 |
+
DEFAULT_CFG,
|
| 28 |
+
LOGGER,
|
| 29 |
+
RANK,
|
| 30 |
+
TQDM,
|
| 31 |
+
__version__,
|
| 32 |
+
callbacks,
|
| 33 |
+
clean_url,
|
| 34 |
+
colorstr,
|
| 35 |
+
emojis,
|
| 36 |
+
yaml_save,
|
| 37 |
+
)
|
| 38 |
+
from yolov8_model.ultralytics.utils.autobatch import check_train_batch_size
|
| 39 |
+
from yolov8_model.ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
|
| 40 |
+
from yolov8_model.ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
| 41 |
+
from yolov8_model.ultralytics.utils.files import get_latest_run
|
| 42 |
+
from yolov8_model.ultralytics.utils.torch_utils import (
|
| 43 |
+
EarlyStopping,
|
| 44 |
+
ModelEMA,
|
| 45 |
+
de_parallel,
|
| 46 |
+
init_seeds,
|
| 47 |
+
one_cycle,
|
| 48 |
+
select_device,
|
| 49 |
+
strip_optimizer,
|
| 50 |
+
)
|
| 51 |
+
from yolov8_model.ultralytics.nn.extra_modules.kernel_warehouse import get_temperature
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BaseTrainer:
|
| 55 |
+
"""
|
| 56 |
+
BaseTrainer.
|
| 57 |
+
|
| 58 |
+
A base class for creating trainers.
|
| 59 |
+
|
| 60 |
+
Attributes:
|
| 61 |
+
args (SimpleNamespace): Configuration for the trainer.
|
| 62 |
+
validator (BaseValidator): Validator instance.
|
| 63 |
+
model (nn.Module): Model instance.
|
| 64 |
+
callbacks (defaultdict): Dictionary of callbacks.
|
| 65 |
+
save_dir (Path): Directory to save results.
|
| 66 |
+
wdir (Path): Directory to save weights.
|
| 67 |
+
last (Path): Path to the last checkpoint.
|
| 68 |
+
best (Path): Path to the best checkpoint.
|
| 69 |
+
save_period (int): Save checkpoint every x epochs (disabled if < 1).
|
| 70 |
+
batch_size (int): Batch size for training.
|
| 71 |
+
epochs (int): Number of epochs to train for.
|
| 72 |
+
start_epoch (int): Starting epoch for training.
|
| 73 |
+
device (torch.device): Device to use for training.
|
| 74 |
+
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
|
| 75 |
+
scaler (amp.GradScaler): Gradient scaler for AMP.
|
| 76 |
+
data (str): Path to data.
|
| 77 |
+
trainset (torch.utils.data.Dataset): Training dataset.
|
| 78 |
+
testset (torch.utils.data.Dataset): Testing dataset.
|
| 79 |
+
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
| 80 |
+
resume (bool): Resume training from a checkpoint.
|
| 81 |
+
lf (nn.Module): Loss function.
|
| 82 |
+
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
| 83 |
+
best_fitness (float): The best fitness value achieved.
|
| 84 |
+
fitness (float): Current fitness value.
|
| 85 |
+
loss (float): Current loss value.
|
| 86 |
+
tloss (float): Total loss value.
|
| 87 |
+
loss_names (list): List of loss names.
|
| 88 |
+
csv (Path): Path to results CSV file.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the BaseTrainer class.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (dict, optional): Pre-built callback registry; when None the default callbacks are used.
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)  # may replace self.args with the args stored in a resume checkpoint
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}  # registry of produced plot files (see on_plot)
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in (-1, 0):  # only the main process touches the filesystem
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:  # single-process run: echo the resolved arguments
        print_args(vars(self.args))

    # Device
    if self.device.type in ("cpu", "mps"):
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
    try:
        if self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data)
        elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in ("detect", "segment", "pose"):
            self.data = check_det_dataset(self.args.data)
            if "yaml_file" in self.data:
                self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        # surface dataset problems with a readable message while keeping the original traceback
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

    self.trainset, self.testset = self.get_dataset(self.data)
    self.ema = None  # created later in _setup_train (main process only)

    # Optimization utils init (populated by _setup_scheduler / _setup_train)
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]  # batch indices whose training samples get plotted

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in (-1, 0):
        callbacks.add_integration_callbacks(self)  # attach third-party logger integrations
|
| 160 |
+
|
| 161 |
+
def add_callback(self, event: str, callback):
    """Register an additional handler for *event*, keeping any existing ones."""
    handlers = self.callbacks[event]
    handlers.append(callback)
|
| 164 |
+
|
| 165 |
+
def set_callback(self, event: str, callback):
    """Discard existing handlers for *event* and install *callback* as the only one."""
    replacement = [callback]
    self.callbacks[event] = replacement
|
| 168 |
+
|
| 169 |
+
def run_callbacks(self, event: str):
    """Invoke every handler registered for *event*, passing the trainer itself."""
    handlers = self.callbacks.get(event, [])
    for handler in handlers:
        handler(self)
|
| 173 |
+
|
| 174 |
+
def train(self):
    """
    Top-level training entry point.

    Derives the DDP world size from args.device, then either spawns a DDP
    subprocess command (multi-GPU, main process only) or trains directly in
    this process. Allow device='', device=None on Multi-GPU systems to default
    to device=0.
    """
    if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(self.args.device.split(","))
    elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(self.args.device)
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # i.e. device='cpu' or 'mps'
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:  # LOCAL_RANK set => already a spawned DDP worker
        # Argument checks: these features are incompatible with multi-GPU DDP
        if self.args.rect:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch == -1:
            LOGGER.warning(
                "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
                "default 'batch=16'"
            )
            self.args.batch = 16

        # Command: build and launch the torchrun-style DDP command; each worker re-enters train()
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
            subprocess.run(cmd, check=True)
        except Exception as e:
            raise e
        finally:
            ddp_cleanup(self, str(file))  # always remove the temporary DDP launch file

    else:
        self._do_train(world_size)
|
| 210 |
+
|
| 211 |
+
def _setup_scheduler(self):
|
| 212 |
+
"""Initialize training learning rate scheduler."""
|
| 213 |
+
if self.args.cos_lr:
|
| 214 |
+
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
| 215 |
+
else:
|
| 216 |
+
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
| 217 |
+
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
| 218 |
+
|
| 219 |
+
def _setup_ddp(self, world_size):
    """Initializes and sets the DistributedDataParallel parameters for training."""
    torch.cuda.set_device(RANK)  # pin this worker process to its own GPU
    self.device = torch.device("cuda", RANK)
    # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
    os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
    dist.init_process_group(
        "nccl" if dist.is_nccl_available() else "gloo",  # NCCL for CUDA, gloo as fallback
        timeout=timedelta(seconds=10800),  # 3 hours
        rank=RANK,
        world_size=world_size,
    )
|
| 231 |
+
|
| 232 |
+
def _setup_train(self, world_size):
    """
    Builds dataloaders and optimizer on correct rank process.

    Order matters: model -> freeze -> AMP check -> DDP wrap -> imgsz/batch
    resolution -> dataloaders -> optimizer -> scheduler -> resume.
    """

    # Model
    self.run_callbacks("on_pretrain_routine_start")
    ckpt = self.setup_model()  # returns checkpoint dict when loading from a .pt, else None
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Freeze layers: args.freeze may be a list of indices, an int (freeze first N), or absent
    freeze_list = (
        self.args.freeze
        if isinstance(self.args.freeze, list)
        else range(self.args.freeze)
        if isinstance(self.args.freeze, int)
        else []
    )
    always_freeze_names = [".dfl"]  # always freeze these layers
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    for k, v in self.model.named_parameters():
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        elif not v.requires_grad:
            # parameters frozen outside the freeze list are re-enabled with a warning
            LOGGER.info(
                f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                "See ultralytics.engine.trainer for customization of frozen layers."
            )
            v.requires_grad = True

    # Check AMP: verify mixed precision produces sane outputs before enabling it
    self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
    if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup  # restore callbacks
    if RANK > -1 and world_size > 1:  # DDP
        dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
    self.amp = bool(self.amp)  # as boolean
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
    if world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

    # Check imgsz: round image size to a multiple of the model's max stride
    gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
    self.stride = gs  # for multi-scale training

    # Batch size
    if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
        self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)

    # Dataloaders: per-process batch = global batch / world size
    batch_size = self.batch_size // max(world_size, 1)
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
    if RANK in (-1, 0):
        # NOTE: When training DOTA dataset, double batch size could get OOM cause some images got more than 2000 objects.
        self.test_loader = self.get_dataloader(
            self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
        )
        self.validator = self.get_validator()
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
        self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # zero-initialized metric dict
        self.ema = ModelEMA(self.model)
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer: gradient accumulation simulates a nominal batch size of args.nbs
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )
    # Scheduler
    self._setup_scheduler()
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)  # restore optimizer/EMA/epoch state when resuming
    self.scheduler.last_epoch = self.start_epoch - 1  # do not move
    self.run_callbacks("on_pretrain_routine_end")
|
| 318 |
+
|
| 319 |
+
def _do_train(self, world_size=1):
    """
    Run the full training loop on this process.

    Handles DDP setup, warmup LR/momentum interpolation, AMP forward/backward,
    gradient accumulation, periodic validation/checkpointing, early stopping,
    optional time-based stopping, and final evaluation/plotting.
    """
    if world_size > 1:
        self._setup_ddp(world_size)
    self._setup_train(world_size)

    nb = len(self.train_loader)  # number of batches
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
    last_opt_step = -1
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
        f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        # also plot the first batches after mosaic augmentation is disabled
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.epochs  # predefine for resume fully trained model edge cases
    for epoch in range(self.start_epoch, self.epochs):
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        self.model.train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)  # reshuffle per epoch under DDP
        pbar = enumerate(self.train_loader)
        # Update dataloader attributes (optional): disable mosaic for the final close_mosaic epochs
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in (-1, 0):
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)  # progress bar only on main process
        self.tloss = None
        self.optimizer.zero_grad()
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")
            # Warmup: linearly ramp LR/momentum over the first nw iterations
            ni = i + nb * epoch  # global iteration counter
            if ni <= nw:
                xi = [0, nw]  # x interp
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # KernelWarehouse-style models anneal an attention temperature during early epochs
            if hasattr(self.model, 'net_update_temperature'):
                temp = get_temperature(i + 1, epoch, len(self.train_loader), temp_epoch=20, temp_init_value=1.0)
                self.model.net_update_temperature(temp)

            # Forward (under autocast when AMP is enabled)
            with torch.cuda.amp.autocast(self.amp):
                batch = self.preprocess_batch(batch)
                self.loss, self.loss_items = self.model(batch)
                if RANK != -1:
                    self.loss *= world_size  # DDP averages gradients; rescale so the sum matches single-GPU
                self.tloss = (
                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                )  # running mean of loss items

            # Backward
            self.scaler.scale(self.loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= self.accumulate:  # step only every `accumulate` iterations
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping: stop when the configured wall-clock budget is exhausted
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # if DDP training
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
            loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
            losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
            if RANK in (-1, 0):
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                    % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
        self.run_callbacks("on_train_epoch_end")
        if RANK in (-1, 0):
            final_epoch = epoch + 1 == self.epochs
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

            # Validation
            if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
                self.metrics, self.fitness = self.validate()
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness)  # early-stopping decision
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
            if self.args.time:
                # re-estimate total epochs from the time budget and mean epoch duration
                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
                self._setup_scheduler()
                self.scheduler.last_epoch = self.epoch  # do not move
                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
            self.scheduler.step()
        self.run_callbacks("on_fit_epoch_end")
        torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

        # Early Stopping: keep all DDP ranks in agreement before breaking
        if RANK != -1:  # if DDP training
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break all DDP ranks

    if RANK in (-1, 0):
        # Do final val with best.pt
        LOGGER.info(
            f"\n{epoch - self.start_epoch + 1} epochs completed in "
            f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
        )
        self.final_eval()
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    torch.cuda.empty_cache()
    self.run_callbacks("teardown")
|
| 476 |
+
|
| 477 |
+
def save_model(self):
    """
    Save model training checkpoints with additional metadata.

    Writes last.pt every call, best.pt when the current fitness equals the best
    so far, and a periodic epoch{N}.pt snapshot when save_period is set.
    """
    import pandas as pd  # scope for faster startup

    metrics = {**self.metrics, **{"fitness": self.fitness}}
    # re-read results.csv so the checkpoint carries the full per-epoch history
    results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
    ckpt = {
        "epoch": self.epoch,
        "best_fitness": self.best_fitness,
        "model": deepcopy(de_parallel(self.model)).half(),  # FP16 copy halves checkpoint size
        "ema": deepcopy(self.ema.ema).half(),
        "updates": self.ema.updates,
        "optimizer": self.optimizer.state_dict(),
        "train_args": vars(self.args),  # save as dict
        "train_metrics": metrics,
        "train_results": results,
        "date": datetime.now().isoformat(),
        "version": __version__,
    }

    # Save last and best
    torch.save(ckpt, self.last)
    if self.best_fitness == self.fitness:  # current epoch is the best so far
        torch.save(ckpt, self.best)
    if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
        torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")  # periodic snapshot
|
| 503 |
+
|
| 504 |
+
@staticmethod
def get_dataset(data):
    """
    Get train, val path from data dict if it exists.

    Returns a (train, val_or_test) tuple; the second element falls back to the
    'test' split when 'val' is absent or falsy, and is None when neither exists.
    """
    train = data["train"]
    val = data.get("val") or data.get("test")
    return train, val
|
| 512 |
+
|
| 513 |
+
def setup_model(self):
    """
    Load/create/download model for any task.

    Returns the checkpoint dict when a .pt file was loaded, otherwise None.
    """
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    weights, ckpt = None, None
    if str(self.model).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(self.model)
        cfg = ckpt["model"].yaml  # architecture yaml embedded in the checkpoint
    else:
        cfg = self.model  # a yaml path / config dict
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return ckpt
|
| 527 |
+
|
| 528 |
+
def optimizer_step(self):
    """Perform a single training optimizer step: unscale, clip, step, update scaler, zero grads, update EMA."""
    scaler, optimizer = self.scaler, self.optimizer
    scaler.unscale_(optimizer)  # unscale gradients before clipping
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    if self.ema:
        self.ema.update(self.model)
|
| 537 |
+
|
| 538 |
+
def preprocess_batch(self, batch):
    """Hook for task-specific preprocessing of inputs/targets; the base implementation is a no-op."""
    return batch
|
| 541 |
+
|
| 542 |
+
def validate(self):
    """
    Run validation on the test set via self.validator.

    Returns (metrics, fitness). When the validator's metrics dict lacks a
    "fitness" key, the negated detached training loss is used instead.
    Updates self.best_fitness when the new fitness improves on it.
    """
    metrics = self.validator(self)
    fallback = -self.loss.detach().cpu().numpy()  # use loss as fitness measure if not found
    fitness = metrics.pop("fitness", fallback)
    if not self.best_fitness or self.best_fitness < fitness:
        self.best_fitness = fitness
    return metrics, fitness
|
| 553 |
+
|
| 554 |
+
def get_model(self, cfg=None, weights=None, verbose=True):
    """Stub: subclasses must construct and return the task-specific model."""
    raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
| 557 |
+
|
| 558 |
+
def get_validator(self):
    """Stub: subclasses must return a task-specific validator instance."""
    raise NotImplementedError("get_validator function not implemented in trainer")
|
| 561 |
+
|
| 562 |
+
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Stub: subclasses must return a dataloader derived from torch.data.Dataloader."""
    raise NotImplementedError("get_dataloader function not implemented in trainer")
|
| 565 |
+
|
| 566 |
+
def build_dataset(self, img_path, mode="train", batch=None):
    """Stub: subclasses must build and return the dataset."""
    raise NotImplementedError("build_dataset function not implemented in trainer")
|
| 569 |
+
|
| 570 |
+
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Return a labelled loss dict, or the bare list of labels when no values are given.

    Note:
        This is not needed for classification but necessary for segmentation & detection
    """
    if loss_items is None:
        return ["loss"]
    return {"loss": loss_items}
|
| 578 |
+
|
| 579 |
+
def set_model_attributes(self):
|
| 580 |
+
"""To set or update model parameters before training."""
|
| 581 |
+
self.model.names = self.data["names"]
|
| 582 |
+
|
| 583 |
+
def build_targets(self, preds, targets):
|
| 584 |
+
"""Builds target tensors for training YOLO model."""
|
| 585 |
+
pass
|
| 586 |
+
|
| 587 |
+
def progress_string(self):
|
| 588 |
+
"""Returns a string describing training progress."""
|
| 589 |
+
return ""
|
| 590 |
+
|
| 591 |
+
# TODO: may need to put these following functions into callback
|
| 592 |
+
def plot_training_samples(self, batch, ni):
|
| 593 |
+
"""Plots training samples during YOLO training."""
|
| 594 |
+
pass
|
| 595 |
+
|
| 596 |
+
def plot_training_labels(self):
|
| 597 |
+
"""Plots training labels for YOLO model."""
|
| 598 |
+
pass
|
| 599 |
+
|
| 600 |
+
def save_metrics(self, metrics):
|
| 601 |
+
"""Saves training metrics to a CSV file."""
|
| 602 |
+
keys, vals = list(metrics.keys()), list(metrics.values())
|
| 603 |
+
n = len(metrics) + 1 # number of cols
|
| 604 |
+
s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header
|
| 605 |
+
with open(self.csv, "a") as f:
|
| 606 |
+
f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
|
| 607 |
+
|
| 608 |
+
def plot_metrics(self):
|
| 609 |
+
"""Plot and display metrics visually."""
|
| 610 |
+
pass
|
| 611 |
+
|
| 612 |
+
def on_plot(self, name, data=None):
|
| 613 |
+
"""Registers plots (e.g. to be consumed in callbacks)"""
|
| 614 |
+
path = Path(name)
|
| 615 |
+
self.plots[path] = {"data": data, "timestamp": time.time()}
|
| 616 |
+
|
| 617 |
+
def final_eval(self):
|
| 618 |
+
"""Performs final evaluation and validation for object detection YOLO model."""
|
| 619 |
+
for f in self.last, self.best:
|
| 620 |
+
if f.exists():
|
| 621 |
+
strip_optimizer(f) # strip optimizers
|
| 622 |
+
if f is self.best:
|
| 623 |
+
LOGGER.info(f"\nValidating {f}...")
|
| 624 |
+
self.validator.args.plots = self.args.plots
|
| 625 |
+
self.metrics = self.validator(model=f)
|
| 626 |
+
self.metrics.pop("fitness", None)
|
| 627 |
+
self.run_callbacks("on_fit_epoch_end")
|
| 628 |
+
|
| 629 |
+
def check_resume(self, overrides):
|
| 630 |
+
"""Check if resume checkpoint exists and update arguments accordingly."""
|
| 631 |
+
resume = self.args.resume
|
| 632 |
+
if resume:
|
| 633 |
+
try:
|
| 634 |
+
exists = isinstance(resume, (str, Path)) and Path(resume).exists()
|
| 635 |
+
last = Path(check_file(resume) if exists else get_latest_run())
|
| 636 |
+
|
| 637 |
+
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
|
| 638 |
+
ckpt_args = attempt_load_weights(last).args
|
| 639 |
+
if not Path(ckpt_args["data"]).exists():
|
| 640 |
+
ckpt_args["data"] = self.args.data
|
| 641 |
+
|
| 642 |
+
resume = True
|
| 643 |
+
self.args = get_cfg(ckpt_args)
|
| 644 |
+
self.args.model = str(last) # reinstate model
|
| 645 |
+
for k in "imgsz", "batch": # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
|
| 646 |
+
if k in overrides:
|
| 647 |
+
setattr(self.args, k, overrides[k])
|
| 648 |
+
|
| 649 |
+
except Exception as e:
|
| 650 |
+
raise FileNotFoundError(
|
| 651 |
+
"Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
|
| 652 |
+
"i.e. 'yolo train resume model=path/to/last.pt'"
|
| 653 |
+
) from e
|
| 654 |
+
self.resume = resume
|
| 655 |
+
|
| 656 |
+
def resume_training(self, ckpt):
|
| 657 |
+
"""Resume YOLO training from given epoch and best fitness."""
|
| 658 |
+
if ckpt is None:
|
| 659 |
+
return
|
| 660 |
+
best_fitness = 0.0
|
| 661 |
+
start_epoch = ckpt["epoch"] + 1
|
| 662 |
+
if ckpt["optimizer"] is not None:
|
| 663 |
+
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
| 664 |
+
best_fitness = ckpt["best_fitness"]
|
| 665 |
+
if self.ema and ckpt.get("ema"):
|
| 666 |
+
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
|
| 667 |
+
self.ema.updates = ckpt["updates"]
|
| 668 |
+
if self.resume:
|
| 669 |
+
assert start_epoch > 0, (
|
| 670 |
+
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
|
| 671 |
+
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
| 672 |
+
)
|
| 673 |
+
LOGGER.info(
|
| 674 |
+
f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
|
| 675 |
+
)
|
| 676 |
+
if self.epochs < start_epoch:
|
| 677 |
+
LOGGER.info(
|
| 678 |
+
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
|
| 679 |
+
)
|
| 680 |
+
self.epochs += ckpt["epoch"] # finetune additional epochs
|
| 681 |
+
self.best_fitness = best_fitness
|
| 682 |
+
self.start_epoch = start_epoch
|
| 683 |
+
if start_epoch > (self.epochs - self.args.close_mosaic):
|
| 684 |
+
self._close_dataloader_mosaic()
|
| 685 |
+
|
| 686 |
+
def _close_dataloader_mosaic(self):
|
| 687 |
+
"""Update dataloaders to stop using mosaic augmentation."""
|
| 688 |
+
if hasattr(self.train_loader.dataset, "mosaic"):
|
| 689 |
+
self.train_loader.dataset.mosaic = False
|
| 690 |
+
if hasattr(self.train_loader.dataset, "close_mosaic"):
|
| 691 |
+
LOGGER.info("Closing dataloader mosaic")
|
| 692 |
+
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
| 693 |
+
|
| 694 |
+
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
| 695 |
+
"""
|
| 696 |
+
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
|
| 697 |
+
weight decay, and number of iterations.
|
| 698 |
+
|
| 699 |
+
Args:
|
| 700 |
+
model (torch.nn.Module): The model for which to build an optimizer.
|
| 701 |
+
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
|
| 702 |
+
based on the number of iterations. Default: 'auto'.
|
| 703 |
+
lr (float, optional): The learning rate for the optimizer. Default: 0.001.
|
| 704 |
+
momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
|
| 705 |
+
decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
|
| 706 |
+
iterations (float, optional): The number of iterations, which determines the optimizer if
|
| 707 |
+
name is 'auto'. Default: 1e5.
|
| 708 |
+
|
| 709 |
+
Returns:
|
| 710 |
+
(torch.optim.Optimizer): The constructed optimizer.
|
| 711 |
+
"""
|
| 712 |
+
|
| 713 |
+
g = [], [], [] # optimizer parameter groups
|
| 714 |
+
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
| 715 |
+
if name == "auto":
|
| 716 |
+
LOGGER.info(
|
| 717 |
+
f"{colorstr('optimizer:')} 'optimizer=auto' found, "
|
| 718 |
+
f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
|
| 719 |
+
f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
|
| 720 |
+
)
|
| 721 |
+
nc = getattr(model, "nc", 10) # number of classes
|
| 722 |
+
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
|
| 723 |
+
name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
|
| 724 |
+
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
|
| 725 |
+
|
| 726 |
+
for module_name, module in model.named_modules():
|
| 727 |
+
for param_name, param in module.named_parameters(recurse=False):
|
| 728 |
+
fullname = f"{module_name}.{param_name}" if module_name else param_name
|
| 729 |
+
if "bias" in fullname: # bias (no decay)
|
| 730 |
+
g[2].append(param)
|
| 731 |
+
elif isinstance(module, bn): # weight (no decay)
|
| 732 |
+
g[1].append(param)
|
| 733 |
+
else: # weight (with decay)
|
| 734 |
+
g[0].append(param)
|
| 735 |
+
|
| 736 |
+
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
|
| 737 |
+
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
| 738 |
+
elif name == "RMSProp":
|
| 739 |
+
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
| 740 |
+
elif name == "SGD":
|
| 741 |
+
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
| 742 |
+
else:
|
| 743 |
+
raise NotImplementedError(
|
| 744 |
+
f"Optimizer '{name}' not found in list of available optimizers "
|
| 745 |
+
f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
|
| 746 |
+
"To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
|
| 750 |
+
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
|
| 751 |
+
LOGGER.info(
|
| 752 |
+
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
| 753 |
+
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
|
| 754 |
+
)
|
| 755 |
+
return optimizer
|
yolov8_model/ultralytics/engine/tuner.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
"""
|
| 3 |
+
This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
|
| 4 |
+
instance segmentation, image classification, pose estimation, and multi-object tracking.
|
| 5 |
+
|
| 6 |
+
Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
|
| 7 |
+
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
|
| 8 |
+
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
|
| 9 |
+
|
| 10 |
+
Example:
|
| 11 |
+
Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
|
| 12 |
+
```python
|
| 13 |
+
from ultralytics import YOLO
|
| 14 |
+
|
| 15 |
+
model = YOLO('yolov8n.pt')
|
| 16 |
+
model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
|
| 17 |
+
```
|
| 18 |
+
"""
|
| 19 |
+
import random
|
| 20 |
+
import shutil
|
| 21 |
+
import subprocess
|
| 22 |
+
import time
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
from yolov8_model.ultralytics.cfg import get_cfg, get_save_dir
|
| 28 |
+
from yolov8_model.ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
|
| 29 |
+
from yolov8_model.ultralytics.utils.plotting import plot_tune_results
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Tuner:
    """
    Class responsible for hyperparameter tuning of YOLO models.

    The class evolves YOLO model hyperparameters over a given number of iterations
    by mutating them according to the search space and retraining the model to evaluate their performance.

    Attributes:
        space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
        tune_dir (Path): Directory where evolution logs and results will be saved.
        tune_csv (Path): Path to the CSV file where evolution logs are saved.

    Methods:
        _mutate(hyp: dict) -> dict:
            Mutates the given hyperparameters within the bounds specified in `self.space`.

        __call__():
            Executes the hyperparameter evolution across multiple iterations.

    Example:
        Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
        ```python
        from ultralytics import YOLO

        model = YOLO('yolov8n.pt')
        model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
        ```

        Tune with custom search space.
        ```python
        from ultralytics import YOLO

        model = YOLO('yolov8n.pt')
        model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
        ```
    """

    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
        """
        Initialize the Tuner with configurations.

        Args:
            args (dict, optional): Configuration for hyperparameter evolution.
            _callbacks (list, optional): Callback functions to run during tuning; defaults to the
                standard callback set when not provided.
        """
        # Search space: key -> (min, max, gain(optional)); the optional gain scales mutation magnitude.
        self.space = args.pop("space", None) or {
            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
            "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
            "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
            "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
            "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
            "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
            "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
            "box": (1.0, 20.0),  # box loss gain
            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
            "dfl": (0.4, 6.0),  # dfl loss gain
            "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
            "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
            "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
            "degrees": (0.0, 45.0),  # image rotation (+/- deg)
            "translate": (0.0, 0.9),  # image translation (+/- fraction)
            "scale": (0.0, 0.95),  # image scale (+/- gain)
            "shear": (0.0, 10.0),  # image shear (+/- deg)
            "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
            "flipud": (0.0, 1.0),  # image flip up-down (probability)
            "fliplr": (0.0, 1.0),  # image flip left-right (probability)
            "mosaic": (0.0, 1.0),  # image mosaic (probability)
            "mixup": (0.0, 1.0),  # image mixup (probability)
            "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
        }
        self.args = get_cfg(overrides=args)
        self.tune_dir = get_save_dir(self.args, name="tune")
        self.tune_csv = self.tune_dir / "tune_results.csv"
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.prefix = colorstr("Tuner: ")
        callbacks.add_integration_callbacks(self)
        LOGGER.info(
            f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
        )

    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
        """
        Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.

        Args:
            parent (str): Parent selection method: 'single' or 'weighted'.
            n (int): Number of parents to consider.
            mutation (float): Probability of a parameter mutation in any given iteration.
            sigma (float): Standard deviation for Gaussian random number generator.

        Returns:
            (dict): A dictionary containing mutated hyperparameters, clamped to the
                search-space bounds and rounded to 5 significant digits.
        """
        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
            # Select parent(s) from previous results, ranked by fitness
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            n = min(n, len(x))  # number of previous results to consider
            x = x[np.argsort(-fitness)][:n]  # top n mutations
            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
            if parent == "single" or len(x) == 1:
                # x = x[random.randint(0, n - 1)]  # random selection
                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
            elif parent == "weighted":
                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

            # Mutate: multiply each hyp by a random factor; loop until at least one factor != 1
            r = np.random  # method
            r.seed(int(time.time()))
            g = np.array([v[2] if len(v) == 3 else 1.0 for k, v in self.space.items()])  # gains 0-1
            ng = len(self.space)
            v = np.ones(ng)
            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
            # x[0] is fitness, so hyp values start at column i + 1
            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
        else:
            # First iteration: start from the configured argument values
            hyp = {k: getattr(self.args, k) for k in self.space.keys()}

        # Constrain to limits
        for k, v in self.space.items():
            hyp[k] = max(hyp[k], v[0])  # lower limit
            hyp[k] = min(hyp[k], v[1])  # upper limit
            hyp[k] = round(hyp[k], 5)  # significant digits

        return hyp

    def __call__(self, model=None, iterations=10, cleanup=True):
        """
        Executes the hyperparameter evolution process when the Tuner instance is called.

        This method iterates through the number of iterations, performing the following steps in each iteration:
        1. Load the existing hyperparameters or initialize new ones.
        2. Mutate the hyperparameters using the `mutate` method.
        3. Train a YOLO model with the mutated hyperparameters.
        4. Log the fitness score and mutated hyperparameters to a CSV file.

        Args:
            model (Model): A pre-initialized YOLO model to be used for training.
            iterations (int): The number of generations to run the evolution for.
            cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.

        Note:
            The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
            Ensure this path is set correctly in the Tuner instance.
        """

        t0 = time.time()
        best_save_dir, best_metrics = None, None
        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
        for i in range(iterations):
            # Mutate hyperparameters
            mutated_hyp = self._mutate()
            LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")

            metrics = {}
            train_args = {**vars(self.args), **mutated_hyp}
            save_dir = get_save_dir(get_cfg(train_args))
            weights_dir = save_dir / "weights"
            # NOTE(review): best.pt existence is checked BEFORE training runs, so ckpt_file may
            # resolve to last.pt even when best.pt is produced by the run — verify intent.
            ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
            try:
                # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
                cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
                return_code = subprocess.run(cmd, check=True).returncode
                metrics = torch.load(ckpt_file)["train_metrics"]
                assert return_code == 0, "training failed"

            except Exception as e:
                # Best-effort: a failed iteration is logged with fitness 0.0 and tuning continues
                LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}")

            # Save results and mutated_hyp to CSV (header written only on first iteration)
            fitness = metrics.get("fitness", 0.0)
            log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
            headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
            with open(self.tune_csv, "a") as f:
                f.write(headers + ",".join(map(str, log_row)) + "\n")

            # Get best results across all iterations so far
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            best_idx = fitness.argmax()
            best_is_current = best_idx == i
            if best_is_current:
                # Current iteration is the best so far: keep its checkpoints
                best_save_dir = save_dir
                best_metrics = {k: round(v, 5) for k, v in metrics.items()}
                for ckpt in weights_dir.glob("*.pt"):
                    shutil.copy2(ckpt, self.tune_dir / "weights")
            elif cleanup:
                shutil.rmtree(ckpt_file.parent)  # remove iteration weights/ dir to reduce storage space

            # Plot tune results
            plot_tune_results(self.tune_csv)

            # Save and print tune results
            header = (
                f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
                f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
                f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
                f'{self.prefix}Best fitness metrics are {best_metrics}\n'
                f'{self.prefix}Best fitness model is {best_save_dir}\n'
                f'{self.prefix}Best fitness hyperparameters are printed below.\n'
            )
            LOGGER.info("\n" + header)
            # Column 0 of the CSV is fitness, so hyp values start at column i + 1
            data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
            yaml_save(
                self.tune_dir / "best_hyperparameters.yaml",
                data=data,
                header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
            )
            yaml_print(self.tune_dir / "best_hyperparameters.yaml")
|
yolov8_model/ultralytics/engine/validator.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
"""
|
| 3 |
+
Check a model's accuracy on a test or val split of a dataset.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
$ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
|
| 7 |
+
|
| 8 |
+
Usage - formats:
|
| 9 |
+
$ yolo mode=val model=yolov8n.pt # PyTorch
|
| 10 |
+
yolov8n.torchscript # TorchScript
|
| 11 |
+
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
| 12 |
+
yolov8n_openvino_model # OpenVINO
|
| 13 |
+
yolov8n.engine # TensorRT
|
| 14 |
+
yolov8n.mlpackage # CoreML (macOS-only)
|
| 15 |
+
yolov8n_saved_model # TensorFlow SavedModel
|
| 16 |
+
yolov8n.pb # TensorFlow GraphDef
|
| 17 |
+
yolov8n.tflite # TensorFlow Lite
|
| 18 |
+
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
| 19 |
+
yolov8n_paddle_model # PaddlePaddle
|
| 20 |
+
"""
|
| 21 |
+
import json
|
| 22 |
+
import time
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from yolov8_model.ultralytics.cfg import get_cfg, get_save_dir
|
| 29 |
+
from yolov8_model.ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
| 30 |
+
from yolov8_model.ultralytics.nn.autobackend import AutoBackend
|
| 31 |
+
from yolov8_model.ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
|
| 32 |
+
from yolov8_model.ultralytics.utils.checks import check_imgsz
|
| 33 |
+
from yolov8_model.ultralytics.utils.ops import Profile
|
| 34 |
+
from yolov8_model.ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BaseValidator:
|
| 38 |
+
"""
|
| 39 |
+
BaseValidator.
|
| 40 |
+
|
| 41 |
+
A base class for creating validators.
|
| 42 |
+
|
| 43 |
+
Attributes:
|
| 44 |
+
args (SimpleNamespace): Configuration for the validator.
|
| 45 |
+
dataloader (DataLoader): Dataloader to use for validation.
|
| 46 |
+
pbar (tqdm): Progress bar to update during validation.
|
| 47 |
+
model (nn.Module): Model to validate.
|
| 48 |
+
data (dict): Data dictionary.
|
| 49 |
+
device (torch.device): Device to use for validation.
|
| 50 |
+
batch_i (int): Current batch index.
|
| 51 |
+
training (bool): Whether the model is in training mode.
|
| 52 |
+
names (dict): Class names.
|
| 53 |
+
seen: Records the number of images seen so far during validation.
|
| 54 |
+
stats: Placeholder for statistics during validation.
|
| 55 |
+
confusion_matrix: Placeholder for a confusion matrix.
|
| 56 |
+
nc: Number of classes.
|
| 57 |
+
iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
|
| 58 |
+
jdict (dict): Dictionary to store JSON validation results.
|
| 59 |
+
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
|
| 60 |
+
batch processing times in milliseconds.
|
| 61 |
+
save_dir (Path): Directory to save results.
|
| 62 |
+
plots (dict): Dictionary to store plots for visualization.
|
| 63 |
+
callbacks (dict): Dictionary to store various callback functions.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
| 67 |
+
"""
|
| 68 |
+
Initializes a BaseValidator instance.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
| 72 |
+
save_dir (Path, optional): Directory to save results.
|
| 73 |
+
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
| 74 |
+
args (SimpleNamespace): Configuration for the validator.
|
| 75 |
+
_callbacks (dict): Dictionary to store various callback functions.
|
| 76 |
+
"""
|
| 77 |
+
self.args = get_cfg(overrides=args)
|
| 78 |
+
self.dataloader = dataloader
|
| 79 |
+
self.pbar = pbar
|
| 80 |
+
self.stride = None
|
| 81 |
+
self.data = None
|
| 82 |
+
self.device = None
|
| 83 |
+
self.batch_i = None
|
| 84 |
+
self.training = True
|
| 85 |
+
self.names = None
|
| 86 |
+
self.seen = None
|
| 87 |
+
self.stats = None
|
| 88 |
+
self.confusion_matrix = None
|
| 89 |
+
self.nc = None
|
| 90 |
+
self.iouv = None
|
| 91 |
+
self.jdict = None
|
| 92 |
+
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
| 93 |
+
|
| 94 |
+
self.save_dir = save_dir or get_save_dir(self.args)
|
| 95 |
+
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
| 96 |
+
if self.args.conf is None:
|
| 97 |
+
self.args.conf = 0.001 # default conf=0.001
|
| 98 |
+
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
|
| 99 |
+
|
| 100 |
+
self.plots = {}
|
| 101 |
+
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
| 102 |
+
|
| 103 |
+
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
    """Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
    gets priority).
    """
    self.training = trainer is not None
    augment = self.args.augment and (not self.training)  # TTA only for standalone validation
    if self.training:
        # Mid-training validation: reuse the trainer's device/data and EMA weights.
        self.device = trainer.device
        self.data = trainer.data
        self.args.half = self.device.type != "cpu"  # force FP16 val during training
        model = trainer.ema.ema or trainer.model
        model = model.half() if self.args.half else model.float()
        # self.model = model
        self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
        # Only produce plots at a possible early stop or on the final epoch.
        self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
        model.eval()
    else:
        # Standalone validation: load the model through AutoBackend (pt/onnx/engine/...).
        callbacks.add_integration_callbacks(self)
        model = AutoBackend(
            model or self.args.model,
            device=select_device(self.args.device, self.args.batch),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
        )
        # self.model = model
        self.device = model.device  # update device
        self.args.half = model.fp16  # update half
        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
        imgsz = check_imgsz(self.args.imgsz, stride=stride)
        if engine:
            self.args.batch = model.batch_size
        elif not pt and not jit:
            self.args.batch = 1  # export.py models default to batch-size 1
            LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

        # Resolve the dataset: YAML files go to the detection checker, otherwise classify.
        if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
            self.data = check_det_dataset(self.args.data)
        elif self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data, split=self.args.split)
        else:
            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

        if self.device.type in ("cpu", "mps"):
            self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
        if not pt:
            self.args.rect = False
        self.stride = model.stride  # used in get_dataloader() for padding
        self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

        model.eval()
        model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

    self.run_callbacks("on_val_start")
    # One profiler per stage: preprocess, inference, loss, postprocess (keys of self.speed).
    dt = (
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
        Profile(device=self.device),
    )
    bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
    self.init_metrics(de_parallel(model))
    self.jdict = []  # empty before each val
    for batch_i, batch in enumerate(bar):
        self.run_callbacks("on_val_batch_start")
        self.batch_i = batch_i
        # Preprocess
        with dt[0]:
            batch = self.preprocess(batch)

        # Inference
        with dt[1]:
            preds = model(batch["img"], augment=augment)

        # Loss (training-time validation only)
        with dt[2]:
            if self.training:
                self.loss += model.loss(batch, preds)[1]

        # Postprocess
        with dt[3]:
            preds = self.postprocess(preds)

        self.update_metrics(preds, batch)
        if self.args.plots and batch_i < 3:  # only plot the first 3 batches
            self.plot_val_samples(batch, batch_i)
            self.plot_predictions(batch, preds, batch_i)

        self.run_callbacks("on_val_batch_end")
    stats = self.get_stats()
    self.check_stats(stats)
    # Convert accumulated profiler times to ms-per-image for each stage.
    self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
    self.finalize_metrics()
    self.print_results()
    self.run_callbacks("on_val_end")
    if self.training:
        model.float()  # restore FP32 for continued training
        results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
        return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
    else:
        LOGGER.info(
            "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
            % tuple(self.speed.values())
        )
        if self.args.save_json and self.jdict:
            with open(str(self.save_dir / "predictions.json"), "w") as f:
                LOGGER.info(f"Saving {f.name}...")
                json.dump(self.jdict, f)  # flatten and save
            stats = self.eval_json(stats)  # update stats
        if self.args.plots or self.args.save_json:
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
        return stats
|
| 216 |
+
|
| 217 |
+
def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
    """
    Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

    Args:
        pred_classes (torch.Tensor): Predicted class indices of shape(N,).
        true_classes (torch.Tensor): Target class indices of shape(M,).
        iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
        use_scipy (bool): Whether to use scipy for matching (more precise).

    Returns:
        (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
    """
    # Dx10 matrix, where D - detections, 10 - IoU thresholds
    correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
    # LxD matrix where L - labels (rows), D - detections (columns)
    correct_class = true_classes[:, None] == pred_classes
    iou = iou * correct_class  # zero out the wrong classes
    iou = iou.cpu().numpy()
    for i, threshold in enumerate(self.iouv.cpu().tolist()):
        if use_scipy:
            # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
            import scipy  # scope import to avoid importing for all commands

            # Hungarian assignment: globally optimal one-to-one label/detection pairing.
            cost_matrix = iou * (iou >= threshold)
            if cost_matrix.any():
                labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
                valid = cost_matrix[labels_idx, detections_idx] > 0
                if valid.any():
                    correct[detections_idx[valid], i] = True
        else:
            matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
            matches = np.array(matches).T  # each row is a (label, detection) candidate pair
            if matches.shape[0]:
                if matches.shape[0] > 1:
                    # Greedy matching: sort pairs by IoU descending, then keep at most one
                    # match per detection (col 1) and one per label (col 0). np.unique's
                    # return_index picks the first (highest-IoU) occurrence of each.
                    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    # matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
|
| 258 |
+
|
| 259 |
+
def add_callback(self, event: str, callback):
    """Appends the given callback to the list registered for *event*."""
    self.callbacks[event].append(callback)
|
| 262 |
+
|
| 263 |
+
def run_callbacks(self, event: str):
    """Invoke every callback registered for *event*, passing this validator as the sole argument."""
    for hook in self.callbacks.get(event, []):
        hook(self)
|
| 267 |
+
|
| 268 |
+
def get_dataloader(self, dataset_path, batch_size):
    """Get data loader from dataset path and batch size.

    Raises:
        NotImplementedError: Always; task-specific subclasses must implement this.
    """
    raise NotImplementedError("get_dataloader function not implemented for this validator")
|
| 271 |
+
|
| 272 |
+
def build_dataset(self, img_path):
    """Build dataset.

    Raises:
        NotImplementedError: Always; task-specific subclasses must implement this.
    """
    raise NotImplementedError("build_dataset function not implemented in validator")
|
| 275 |
+
|
| 276 |
+
def preprocess(self, batch):
    """Preprocesses an input batch. Identity by default; subclasses override (e.g. normalization, device moves)."""
    return batch
|
| 279 |
+
|
| 280 |
+
def postprocess(self, preds):
    """Post-processes raw model predictions. Identity by default; subclasses override (e.g. to apply NMS)."""
    return preds
|
| 283 |
+
|
| 284 |
+
def init_metrics(self, model):
    """Initialize performance metrics for the YOLO model. No-op by default; subclasses override."""
    pass
|
| 287 |
+
|
| 288 |
+
def update_metrics(self, preds, batch):
    """Updates metrics based on predictions and batch. No-op by default; subclasses override."""
    pass
|
| 291 |
+
|
| 292 |
+
def finalize_metrics(self, *args, **kwargs):
    """Finalizes and returns all metrics. No-op by default; subclasses override."""
    pass
|
| 295 |
+
|
| 296 |
+
def get_stats(self):
    """Returns statistics about the model's performance. Empty dict by default; subclasses override."""
    return {}
|
| 299 |
+
|
| 300 |
+
def check_stats(self, stats):
    """Checks statistics. No-op by default; subclasses may validate or warn."""
    pass
|
| 303 |
+
|
| 304 |
+
def print_results(self):
    """Prints the results of the model's predictions. No-op by default; subclasses override."""
    pass
|
| 307 |
+
|
| 308 |
+
def get_desc(self):
    """Get description of the YOLO model (used as the progress-bar header); None by default."""
    pass
|
| 311 |
+
|
| 312 |
+
@property
def metric_keys(self):
    """Returns the metric keys used in YOLO training/validation. Empty by default; subclasses override."""
    return []
|
| 316 |
+
|
| 317 |
+
def on_plot(self, name, data=None):
    """Registers plots (e.g. to be consumed in callbacks)"""
    # Keyed by Path so re-registering the same file overwrites the earlier entry.
    self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
|
| 320 |
+
|
| 321 |
+
# TODO: may need to put these following functions into callback
|
| 322 |
+
def plot_val_samples(self, batch, ni):
    """Plots validation samples during training. No-op by default; subclasses override."""
    pass
|
| 325 |
+
|
| 326 |
+
def plot_predictions(self, batch, preds, ni):
    """Plots YOLO model predictions on batch images. No-op by default; subclasses override."""
    pass
|
| 329 |
+
|
| 330 |
+
def pred_to_json(self, preds, batch):
    """Convert predictions to JSON format. No-op by default; subclasses append to self.jdict."""
    pass
|
| 333 |
+
|
| 334 |
+
def eval_json(self, stats):
    """Evaluate and return JSON format of prediction statistics. No-op by default; subclasses override."""
    pass
|
yolov8_model/ultralytics/hub/__init__.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
from yolov8_model.ultralytics.data.utils import HUBDatasetStats
|
| 6 |
+
from yolov8_model.ultralytics.hub.auth import Auth
|
| 7 |
+
from yolov8_model.ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
| 8 |
+
from yolov8_model.ultralytics.utils import LOGGER, SETTINGS, checks
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def login(api_key: str = None, save=True) -> bool:
    """
    Log in to the Ultralytics HUB API using the provided API key.

    The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
    environment variable if successfully authenticated.

    Args:
        api_key (str, optional): API key to use for authentication.
            If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
        save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.

    Returns:
        (bool): True if authentication is successful, False otherwise.
    """
    checks.check_requirements("hub-sdk>=0.0.2")
    from hub_sdk import HUBClient

    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # set the redirect URL
    saved_key = SETTINGS.get("api_key")
    active_key = api_key or saved_key
    credentials = {"api_key": active_key} if active_key and active_key != "" else None  # set credentials

    client = HUBClient(credentials)  # initialize HUBClient

    # Guard clause: authentication failed, point the user at the key page.
    if not client.authenticated:
        LOGGER.info(f"{PREFIX}Retrieve API key from {api_key_url}")
        return False

    # Successfully authenticated with HUB; persist the key if it changed.
    if save and client.api_key != saved_key:
        SETTINGS.update({"api_key": client.api_key})

    # Message depends on whether the key was freshly provided or came from settings.
    log_message = (
        "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
    )
    LOGGER.info(f"{PREFIX}{log_message}")
    return True
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def logout():
    """
    Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.

    Example:
        ```python
        from ultralytics import hub

        hub.logout()
        ```
    """
    SETTINGS["api_key"] = ""  # clear the stored key
    SETTINGS.save()  # persist the change to the settings file
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def reset_model(model_id=""):
    """Reset a trained model to an untrained state."""
    response = requests.post(
        f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key}
    )
    # Early-return on failure; success is the fall-through path.
    if response.status_code != 200:
        LOGGER.warning(f"{PREFIX}Model reset failure {response.status_code} {response.reason}")
        return
    LOGGER.info(f"{PREFIX}Model reset successfully")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def export_fmts_hub():
    """Returns a list of HUB-supported export formats."""
    # Scoped import to avoid loading the exporter for all commands. Uses the same
    # 'yolov8_model.' vendored-package prefix as every other import in this file;
    # the original bare 'ultralytics.' prefix was inconsistent with this layout.
    from yolov8_model.ultralytics.engine.exporter import export_formats

    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def export_model(model_id="", format="torchscript"):
    """Export a model to all formats."""
    supported = export_fmts_hub()  # fetch once; used for both the check and the message
    assert format in supported, f"Unsupported export format '{format}', valid formats are {supported}"
    response = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
    )
    assert response.status_code == 200, f"{PREFIX}{format} export failure {response.status_code} {response.reason}"
    LOGGER.info(f"{PREFIX}{format} export started ✅")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_export(model_id="", format="torchscript"):
    """Get an exported model dictionary with download URL."""
    supported = export_fmts_hub()  # fetch once; used for both the check and the message
    assert format in supported, f"Unsupported export format '{format}', valid formats are {supported}"
    response = requests.post(
        f"{HUB_API_ROOT}/get-export",
        json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
        headers={"x-api-key": Auth().api_key},
    )
    assert response.status_code == 200, f"{PREFIX}{format} get_export failure {response.status_code} {response.reason}"
    return response.json()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def check_dataset(path="", task="detect"):
    """
    Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
    to the HUB. Usage examples are given below.

    Args:
        path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
        task (str, optional): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Defaults to 'detect'.

    Example:
        ```python
        from ultralytics.hub import check_dataset

        check_dataset('path/to/coco8.zip', task='detect')  # detect dataset
        check_dataset('path/to/coco8-seg.zip', task='segment')  # segment dataset
        check_dataset('path/to/coco8-pose.zip', task='pose')  # pose dataset
        ```
    """
    # get_json() raises on a malformed dataset, so reaching the log line means success.
    HUBDatasetStats(path=path, task=task).get_json()
    LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
|
yolov8_model/ultralytics/hub/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (5.02 kB). View file
|
|
|
yolov8_model/ultralytics/hub/__pycache__/auth.cpython-310.pyc
ADDED
|
Binary file (4.32 kB). View file
|
|
|
yolov8_model/ultralytics/hub/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (8.53 kB). View file
|
|
|
yolov8_model/ultralytics/hub/auth.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
from yolov8_model.ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
|
| 6 |
+
from yolov8_model.ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
|
| 7 |
+
|
| 8 |
+
API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Auth:
    """
    Manages authentication processes including API key handling, cookie-based authentication, and header generation.

    The class supports different methods of authentication:
    1. Directly using an API key.
    2. Authenticating using browser cookies (specifically in Google Colab).
    3. Prompting the user to enter an API key.

    Attributes:
        id_token (str or bool): Token used for identity verification, initialized as False.
        api_key (str or bool): API key for authentication, initialized as False.
        model_key (bool): Placeholder for model key, initialized as False.
    """

    # Class-level defaults; instance attributes shadow these once set in __init__.
    id_token = api_key = model_key = False

    def __init__(self, api_key="", verbose=False):
        """
        Initialize the Auth class with an optional API key.

        Args:
            api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
            verbose (bool, optional): Whether to log authentication status messages.
        """
        # Split the input API key in case it contains a combined key_model and keep only the API key part
        api_key = api_key.split("_")[0]

        # Set API key attribute as value passed or SETTINGS API key if none passed
        self.api_key = api_key or SETTINGS.get("api_key", "")

        # If an API key is provided
        if self.api_key:
            # If the provided API key matches the API key in the SETTINGS
            if self.api_key == SETTINGS.get("api_key"):
                # Log that the user is already logged in
                if verbose:
                    LOGGER.info(f"{PREFIX}Authenticated ✅")
                return
            else:
                # Attempt to authenticate with the provided API key
                success = self.authenticate()
        # If the API key is not provided and the environment is a Google Colab notebook
        elif is_colab():
            # Attempt to authenticate using browser cookies
            success = self.auth_with_cookies()
        else:
            # Request an API key interactively (raises after repeated failures)
            success = self.request_api_key()

        # Update SETTINGS with the new API key after successful authentication
        if success:
            SETTINGS.update({"api_key": self.api_key})
            # Log that the new login was successful
            if verbose:
                LOGGER.info(f"{PREFIX}New authentication successful ✅")
        elif verbose:
            LOGGER.info(f"{PREFIX}Retrieve API key from {API_KEY_URL}")

    def request_api_key(self, max_attempts=3):
        """
        Prompt the user to input their API key.

        Args:
            max_attempts (int, optional): Number of interactive attempts before giving up.

        Returns:
            (bool): True on successful authentication.

        Raises:
            ConnectionError: If all attempts fail.
        """
        import getpass  # scoped import: only needed for interactive prompting

        for attempts in range(max_attempts):
            LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
            input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
            self.api_key = input_key.split("_")[0]  # remove model id if present
            if self.authenticate():
                return True
        raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

    def authenticate(self) -> bool:
        """
        Attempt to authenticate with the server using either id_token or API key.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        try:
            if header := self.get_auth_header():
                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
                if not r.json().get("success", False):
                    raise ConnectionError("Unable to authenticate.")
                return True
            raise ConnectionError("User has not authenticated locally.")
        except ConnectionError:
            self.id_token = self.api_key = False  # reset invalid
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
            return False

    def auth_with_cookies(self) -> bool:
        """
        Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a
        supported browser.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        if not is_colab():
            return False  # Currently only works with Colab
        try:
            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
            if authn.get("success", False):
                self.id_token = authn.get("data", {}).get("idToken", None)
                self.authenticate()  # verify the freshly obtained token server-side
                return True
            raise ConnectionError("Unable to fetch browser authentication details.")
        except ConnectionError:
            self.id_token = False  # reset invalid
            return False

    def get_auth_header(self):
        """
        Get the authentication header for making API requests.

        Returns:
            (dict): The authentication header if id_token or API key is set, None otherwise.
        """
        # id_token (bearer) takes precedence over the raw API key.
        if self.id_token:
            return {"authorization": f"Bearer {self.id_token}"}
        elif self.api_key:
            return {"x-api-key": self.api_key}
        # else returns None
|
yolov8_model/ultralytics/hub/session.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import threading
|
| 4 |
+
import time
|
| 5 |
+
from http import HTTPStatus
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import requests
|
| 9 |
+
|
| 10 |
+
from yolov8_model.ultralytics.hub.utils import HUB_WEB_ROOT, HELP_MSG, PREFIX, TQDM
|
| 11 |
+
from yolov8_model.ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
|
| 12 |
+
from yolov8_model.ultralytics.utils.errors import HUBModelError
|
| 13 |
+
|
| 14 |
+
AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class HUBTrainingSession:
|
| 18 |
+
"""
|
| 19 |
+
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
| 20 |
+
|
| 21 |
+
Attributes:
|
| 22 |
+
agent_id (str): Identifier for the instance communicating with the server.
|
| 23 |
+
model_id (str): Identifier for the YOLO model being trained.
|
| 24 |
+
model_url (str): URL for the model in Ultralytics HUB.
|
| 25 |
+
api_url (str): API URL for the model in Ultralytics HUB.
|
| 26 |
+
auth_header (dict): Authentication header for the Ultralytics HUB API requests.
|
| 27 |
+
rate_limits (dict): Rate limits for different API calls (in seconds).
|
| 28 |
+
timers (dict): Timers for rate limiting.
|
| 29 |
+
metrics_queue (dict): Queue for the model's metrics.
|
| 30 |
+
model (dict): Model data fetched from Ultralytics HUB.
|
| 31 |
+
alive (bool): Indicates if the heartbeat loop is active.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, identifier):
    """
    Initialize the HUBTrainingSession with the provided model identifier.

    Args:
        identifier (str): Model identifier used to initialize the HUB training session.
            It can be a URL string or a model key with specific format.

    Raises:
        ValueError: If the provided model identifier is invalid.
        ConnectionError: If connecting with global API key is not supported.
        ModuleNotFoundError: If hub-sdk package is not installed.
    """
    from hub_sdk import HUBClient  # scoped import: optional dependency

    self.rate_limits = {
        "metrics": 3.0,
        "ckpt": 900.0,
        "heartbeat": 300.0,
    }  # rate limits (seconds)
    self.metrics_queue = {}  # holds metrics for each epoch until upload
    self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py

    # Parse input into (api_key, model_id, filename); any of them may be None.
    api_key, model_id, self.filename = self._parse_identifier(identifier)

    # Get credentials: an explicit key in the identifier wins over saved SETTINGS.
    active_key = api_key or SETTINGS.get("api_key")
    credentials = {"api_key": active_key} if active_key else None  # set credentials

    # Initialize client
    self.client = HUBClient(credentials)

    if model_id:
        self.load_model(model_id)  # load existing model
    else:
        self.model = self.client.model()  # load empty model
|
| 71 |
+
|
| 72 |
+
def load_model(self, model_id):
    """Loads an existing model from Ultralytics HUB using the provided model identifier."""
    self.model = self.client.model(model_id)
    if not self.model.data:  # then model does not exist
        raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

    self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

    # Populate training arguments from the fetched model data.
    self._set_train_args()

    # Start heartbeats for HUB to monitor agent
    self.model.start_heartbeat(self.rate_limits["heartbeat"])
    LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
| 85 |
+
|
| 86 |
+
def create_model(self, model_args):
    """Initializes a HUB training session with the specified model identifier.

    Args:
        model_args (dict): Training arguments; recognized keys are 'batch', 'epochs',
            'imgsz', 'patience', 'device', 'cache', and 'data' (missing keys fall back
            to the defaults below).
    """
    payload = {
        "config": {
            "batchSize": model_args.get("batch", -1),
            "epochs": model_args.get("epochs", 300),
            "imageSize": model_args.get("imgsz", 640),
            "patience": model_args.get("patience", 100),
            "device": model_args.get("device", ""),
            "cache": model_args.get("cache", "ram"),
        },
        "dataset": {"name": model_args.get("data")},
        "lineage": {
            "architecture": {
                # Architecture name is the filename without its .pt/.yaml extension.
                "name": self.filename.replace(".pt", "").replace(".yaml", ""),
            },
            "parent": {},
        },
        "meta": {"name": self.filename},
    }

    # A .pt file means training starts from pretrained weights -> record as parent.
    if self.filename.endswith(".pt"):
        payload["lineage"]["parent"]["name"] = self.filename

    self.model.create_model(payload)

    # Model could not be created
    # TODO: improve error handling
    if not self.model.id:
        return

    self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

    # Start heartbeats for HUB to monitor agent
    self.model.start_heartbeat(self.rate_limits["heartbeat"])

    LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
| 123 |
+
|
| 124 |
+
def _parse_identifier(self, identifier):
|
| 125 |
+
"""
|
| 126 |
+
Parses the given identifier to determine the type of identifier and extract relevant components.
|
| 127 |
+
|
| 128 |
+
The method supports different identifier formats:
|
| 129 |
+
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
|
| 130 |
+
- An identifier containing an API key and a model ID separated by an underscore
|
| 131 |
+
- An identifier that is solely a model ID of a fixed length
|
| 132 |
+
- A local filename that ends with '.pt' or '.yaml'
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
identifier (str): The identifier string to be parsed.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
(tuple): A tuple containing the API key, model ID, and filename as applicable.
|
| 139 |
+
|
| 140 |
+
Raises:
|
| 141 |
+
HUBModelError: If the identifier format is not recognized.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
# Initialize variables
|
| 145 |
+
api_key, model_id, filename = None, None, None
|
| 146 |
+
|
| 147 |
+
# Check if identifier is a HUB URL
|
| 148 |
+
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
| 149 |
+
# Extract the model_id after the HUB_WEB_ROOT URL
|
| 150 |
+
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
|
| 151 |
+
else:
|
| 152 |
+
# Split the identifier based on underscores only if it's not a HUB URL
|
| 153 |
+
parts = identifier.split("_")
|
| 154 |
+
|
| 155 |
+
# Check if identifier is in the format of API key and model ID
|
| 156 |
+
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
| 157 |
+
api_key, model_id = parts
|
| 158 |
+
# Check if identifier is a single model ID
|
| 159 |
+
elif len(parts) == 1 and len(parts[0]) == 20:
|
| 160 |
+
model_id = parts[0]
|
| 161 |
+
# Check if identifier is a local filename
|
| 162 |
+
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
|
| 163 |
+
filename = identifier
|
| 164 |
+
else:
|
| 165 |
+
raise HUBModelError(
|
| 166 |
+
f"model='{identifier}' could not be parsed. Check format is correct. "
|
| 167 |
+
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
return api_key, model_id, filename
|
| 171 |
+
|
| 172 |
+
    def _set_train_args(self, **kwargs):
        """
        Initialize training arguments from the HUB model and set the local model file.

        Populates `self.train_args`, `self.model_file` and `self.model_id` from the remote
        model's state: resumes from the last checkpoint when one exists, otherwise rebuilds
        the arguments from the model's stored config.

        Raises:
            ValueError: If the model is already trained, or if its dataset URL is not yet
                available (dataset may still be processing).
        """
        if self.model.is_trained():
            # Model is already trained
            raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))

        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
            self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            def get_train_args(config):
                """Map a HUB model config dict to local YOLO training argument names."""
                return {
                    "batch": config["batchSize"],
                    "epochs": config["epochs"],
                    "imgsz": config["imageSize"],
                    "patience": config["patience"],
                    "device": config["device"],
                    "cache": config["cache"],
                    "data": self.model.get_dataset_url(),
                }

            self.train_args = get_train_args(self.model.data.get("config"))
            # Set the model file as either a *.pt or *.yaml file
            self.model_file = (
                self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
            )

        if not self.train_args.get("data"):
            raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id
|
| 207 |
+
|
| 208 |
+
def request_queue(
|
| 209 |
+
self,
|
| 210 |
+
request_func,
|
| 211 |
+
retry=3,
|
| 212 |
+
timeout=30,
|
| 213 |
+
thread=True,
|
| 214 |
+
verbose=True,
|
| 215 |
+
progress_total=None,
|
| 216 |
+
*args,
|
| 217 |
+
**kwargs,
|
| 218 |
+
):
|
| 219 |
+
def retry_request():
|
| 220 |
+
"""Attempts to call `request_func` with retries, timeout, and optional threading."""
|
| 221 |
+
t0 = time.time() # Record the start time for the timeout
|
| 222 |
+
for i in range(retry + 1):
|
| 223 |
+
if (time.time() - t0) > timeout:
|
| 224 |
+
LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
|
| 225 |
+
break # Timeout reached, exit loop
|
| 226 |
+
|
| 227 |
+
response = request_func(*args, **kwargs)
|
| 228 |
+
if response is None:
|
| 229 |
+
LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
|
| 230 |
+
time.sleep(2**i) # Exponential backoff before retrying
|
| 231 |
+
continue # Skip further processing and retry
|
| 232 |
+
|
| 233 |
+
if progress_total:
|
| 234 |
+
self._show_upload_progress(progress_total, response)
|
| 235 |
+
|
| 236 |
+
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
| 237 |
+
return response # Success, no need to retry
|
| 238 |
+
|
| 239 |
+
if i == 0:
|
| 240 |
+
# Initial attempt, check status code and provide messages
|
| 241 |
+
message = self._get_failure_message(response, retry, timeout)
|
| 242 |
+
|
| 243 |
+
if verbose:
|
| 244 |
+
LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")
|
| 245 |
+
|
| 246 |
+
if not self._should_retry(response.status_code):
|
| 247 |
+
LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
|
| 248 |
+
break # Not an error that should be retried, exit loop
|
| 249 |
+
|
| 250 |
+
time.sleep(2**i) # Exponential backoff for retries
|
| 251 |
+
|
| 252 |
+
return response
|
| 253 |
+
|
| 254 |
+
if thread:
|
| 255 |
+
# Start a new thread to run the retry_request function
|
| 256 |
+
threading.Thread(target=retry_request, daemon=True).start()
|
| 257 |
+
else:
|
| 258 |
+
# If running in the main thread, call retry_request directly
|
| 259 |
+
return retry_request()
|
| 260 |
+
|
| 261 |
+
def _should_retry(self, status_code):
|
| 262 |
+
"""Determines if a request should be retried based on the HTTP status code."""
|
| 263 |
+
retry_codes = {
|
| 264 |
+
HTTPStatus.REQUEST_TIMEOUT,
|
| 265 |
+
HTTPStatus.BAD_GATEWAY,
|
| 266 |
+
HTTPStatus.GATEWAY_TIMEOUT,
|
| 267 |
+
}
|
| 268 |
+
return status_code in retry_codes
|
| 269 |
+
|
| 270 |
+
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
|
| 271 |
+
"""
|
| 272 |
+
Generate a retry message based on the response status code.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
response: The HTTP response object.
|
| 276 |
+
retry: The number of retry attempts allowed.
|
| 277 |
+
timeout: The maximum timeout duration.
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
(str): The retry message.
|
| 281 |
+
"""
|
| 282 |
+
if self._should_retry(response.status_code):
|
| 283 |
+
return f"Retrying {retry}x for {timeout}s." if retry else ""
|
| 284 |
+
elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
|
| 285 |
+
headers = response.headers
|
| 286 |
+
return (
|
| 287 |
+
f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
|
| 288 |
+
f"Please retry after {headers['Retry-After']}s."
|
| 289 |
+
)
|
| 290 |
+
else:
|
| 291 |
+
try:
|
| 292 |
+
return response.json().get("message", "No JSON message.")
|
| 293 |
+
except AttributeError:
|
| 294 |
+
return "Unable to read JSON."
|
| 295 |
+
|
| 296 |
+
def upload_metrics(self):
|
| 297 |
+
"""Upload model metrics to Ultralytics HUB."""
|
| 298 |
+
return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)
|
| 299 |
+
|
| 300 |
+
def upload_model(
|
| 301 |
+
self,
|
| 302 |
+
epoch: int,
|
| 303 |
+
weights: str,
|
| 304 |
+
is_best: bool = False,
|
| 305 |
+
map: float = 0.0,
|
| 306 |
+
final: bool = False,
|
| 307 |
+
) -> None:
|
| 308 |
+
"""
|
| 309 |
+
Upload a model checkpoint to Ultralytics HUB.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
epoch (int): The current training epoch.
|
| 313 |
+
weights (str): Path to the model weights file.
|
| 314 |
+
is_best (bool): Indicates if the current model is the best one so far.
|
| 315 |
+
map (float): Mean average precision of the model.
|
| 316 |
+
final (bool): Indicates if the model is the final model after training.
|
| 317 |
+
"""
|
| 318 |
+
if Path(weights).is_file():
|
| 319 |
+
progress_total = Path(weights).stat().st_size if final else None # Only show progress if final
|
| 320 |
+
self.request_queue(
|
| 321 |
+
self.model.upload_model,
|
| 322 |
+
epoch=epoch,
|
| 323 |
+
weights=weights,
|
| 324 |
+
is_best=is_best,
|
| 325 |
+
map=map,
|
| 326 |
+
final=final,
|
| 327 |
+
retry=10,
|
| 328 |
+
timeout=3600,
|
| 329 |
+
thread=not final,
|
| 330 |
+
progress_total=progress_total,
|
| 331 |
+
)
|
| 332 |
+
else:
|
| 333 |
+
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
|
| 334 |
+
|
| 335 |
+
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
|
| 336 |
+
"""
|
| 337 |
+
Display a progress bar to track the upload progress of a file download.
|
| 338 |
+
|
| 339 |
+
Args:
|
| 340 |
+
content_length (int): The total size of the content to be downloaded in bytes.
|
| 341 |
+
response (requests.Response): The response object from the file download request.
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
None
|
| 345 |
+
"""
|
| 346 |
+
with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
|
| 347 |
+
for data in response.iter_content(chunk_size=1024):
|
| 348 |
+
pbar.update(len(data))
|
yolov8_model/ultralytics/hub/utils.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import platform
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
import threading
|
| 8 |
+
import time
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import requests
|
| 12 |
+
|
| 13 |
+
from yolov8_model.ultralytics.utils import (
|
| 14 |
+
ENVIRONMENT,
|
| 15 |
+
LOGGER,
|
| 16 |
+
ONLINE,
|
| 17 |
+
RANK,
|
| 18 |
+
SETTINGS,
|
| 19 |
+
TESTS_RUNNING,
|
| 20 |
+
TQDM,
|
| 21 |
+
TryExcept,
|
| 22 |
+
__version__,
|
| 23 |
+
colorstr,
|
| 24 |
+
get_git_origin_url,
|
| 25 |
+
is_colab,
|
| 26 |
+
is_git_dir,
|
| 27 |
+
is_pip_package,
|
| 28 |
+
)
|
| 29 |
+
from yolov8_model.ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
|
| 30 |
+
|
| 31 |
+
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
|
| 32 |
+
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")
|
| 33 |
+
|
| 34 |
+
PREFIX = colorstr("Ultralytics HUB: ")
|
| 35 |
+
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def request_with_credentials(url: str) -> any:  # NOTE(review): 'any' is the builtin; 'typing.Any' was likely intended
    """
    Make an AJAX request with cookies attached in a Google Colab environment.

    The request is performed by injected browser-side JavaScript so that the user's existing
    session cookies are sent with it; the JSON result is read back via `output.eval_js`.

    Args:
        url (str): The URL to make the POST request to.

    Returns:
        (any): The response data from the AJAX request.

    Raises:
        OSError: If the function is not run in a Google Colab environment.
    """
    if not is_colab():
        raise OSError("request_with_credentials() must run in a Colab environment")
    # Colab-only imports, deferred so the module imports cleanly elsewhere
    from google.colab import output  # noqa
    from IPython import display  # noqa

    display.display(
        display.Javascript(
            """
            window._hub_tmp = new Promise((resolve, reject) => {
                const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
                fetch("%s", {
                    method: 'POST',
                    credentials: 'include'
                })
                    .then((response) => resolve(response.json()))
                    .then((json) => {
                        clearTimeout(timeout);
                    }).catch((err) => {
                        clearTimeout(timeout);
                        reject(err);
                    });
            });
            """
            % url
        )
    )
    # Block until the browser-side promise resolves and return its value
    return output.eval_js("_hub_tmp")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def requests_with_progress(method, url, **kwargs):
    """
    Make an HTTP request using the specified method and URL, with an optional progress bar.

    Args:
        method (str): The HTTP method to use (e.g. 'GET', 'POST').
        url (str): The URL to send the request to.
        **kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function.

    Returns:
        (requests.Response): The response object from the HTTP request.

    Note:
        - If 'progress' is set to True, the progress bar will display the download progress for responses with a known
          content length.
        - If 'progress' is a number then progress bar will display assuming content length = progress.
    """
    progress = kwargs.pop("progress", False)
    if not progress:
        # No progress requested: plain request, body not streamed
        return requests.request(method, url, **kwargs)
    # Stream the response so the body can be consumed chunk-by-chunk while updating the bar
    response = requests.request(method, url, stream=True, **kwargs)
    total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress)  # total size
    try:
        pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
        for data in response.iter_content(chunk_size=1024):
            pbar.update(len(data))
        pbar.close()
    except requests.exceptions.ChunkedEncodingError:  # avoid 'Connection broken: IncompleteRead' warnings
        response.close()
    return response
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
    """
    Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.

    Args:
        method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
        url (str): The URL to make the request to.
        retry (int, optional): Number of retries to attempt before giving up. Default is 3.
        timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
        thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
        code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
        verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
        progress (bool, optional): Whether to show a progress bar during the request. Default is False.
        **kwargs (dict): Keyword arguments to be passed to the requests function specified in method.

    Returns:
        (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
    """
    retry_codes = (408, 500)  # retry only these codes

    # NOTE(review): TryExcept presumably logs and suppresses exceptions so a failed request
    # never propagates to the caller — confirm against ultralytics.utils.TryExcept.
    @TryExcept(verbose=verbose)
    def func(func_method, func_url, **func_kwargs):
        """Make HTTP requests with retries and timeouts, with optional progress tracking."""
        r = None  # response
        t0 = time.time()  # initial time for timer
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                break
            r = requests_with_progress(func_method, func_url, **func_kwargs)  # i.e. get(url, data, json, files)
            if r.status_code < 300:  # return codes in the 2xx range are generally considered "good" or "successful"
                break
            try:
                m = r.json().get("message", "No JSON message.")
            except AttributeError:
                m = "Unable to read JSON."
            if i == 0:
                # First failed attempt: build a descriptive message and log once
                if r.status_code in retry_codes:
                    m += f" Retrying {retry}x for {timeout}s." if retry else ""
                elif r.status_code == 429:  # rate limit
                    h = r.headers  # response headers
                    m = (
                        f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
                        f"Please retry after {h['Retry-After']}s."
                    )
                if verbose:
                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
                if r.status_code not in retry_codes:
                    return r
            time.sleep(2**i)  # exponential standoff
        return r

    args = method, url
    kwargs["progress"] = progress
    if thread:
        # Fire-and-forget in a daemon thread; caller gets None
        threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
    else:
        return func(*args, **kwargs)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class Events:
    """
    A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
    disabled when sync=False. Run 'yolo settings' to see and update settings YAML file.

    Attributes:
        url (str): The URL to send anonymous events.
        rate_limit (float): The rate limit in seconds for sending events.
        metadata (dict): A dictionary containing metadata about the environment.
        enabled (bool): A flag to enable or disable Events based on certain conditions.
    """

    url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"

    def __init__(self):
        """Initializes the Events object with default values for events, rate_limit, and metadata."""
        self.events = []  # events list
        self.rate_limit = 60.0  # rate limit (seconds)
        self.t = 0.0  # rate limit timer (seconds)
        self.metadata = {
            "cli": Path(sys.argv[0]).name == "yolo",
            "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
            "python": ".".join(platform.python_version_tuple()[:2]),  # i.e. 3.10
            "version": __version__,
            "env": ENVIRONMENT,
            "session_id": round(random.random() * 1e15),
            "engagement_time_msec": 1000,
        }
        # Analytics run only when sync is on, on the main rank, outside tests, online,
        # and for official pip/git installs.
        self.enabled = (
            SETTINGS["sync"]
            and RANK in (-1, 0)
            and not TESTS_RUNNING
            and ONLINE
            and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
        )

    def __call__(self, cfg):
        """
        Attempts to add a new event to the events list and send events if the rate limit is reached.

        Args:
            cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
        """
        if not self.enabled:
            # Events disabled, do nothing
            return

        # Attempt to add to events
        if len(self.events) < 25:  # Events list limited to 25 events (drop any events past this)
            params = {
                **self.metadata,
                "task": cfg.task,
                "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
            }
            if cfg.mode == "export":
                params["format"] = cfg.format
            self.events.append({"name": cfg.mode, "params": params})

        # Check rate limit
        t = time.time()
        if (t - self.t) < self.rate_limit:
            # Time is under rate limiter, wait to send
            return

        # Time is over rate limiter, send now
        data = {"client_id": SETTINGS["uuid"], "events": self.events}  # SHA-256 anonymized UUID hash and events list

        # POST equivalent to requests.post(self.url, json=data)
        smart_request("post", self.url, json=data, retry=0, verbose=False)

        # Reset events and rate limit timer
        self.events = []
        self.t = t
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# Run below code on hub/utils init -------------------------------------------------------------------------------------
|
| 247 |
+
events = Events()
|
yolov8_model/ultralytics/models/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
from .rtdetr import RTDETR
|
| 4 |
+
from .sam import SAM
|
| 5 |
+
from .yolo import YOLO
|
| 6 |
+
|
| 7 |
+
__all__ = "YOLO", "RTDETR", "SAM" # allow simpler import
|
yolov8_model/ultralytics/models/fastsam/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
| 2 |
+
|
| 3 |
+
from .model import FastSAM
|
| 4 |
+
from .predict import FastSAMPredictor
|
| 5 |
+
from .prompt import FastSAMPrompt
|
| 6 |
+
from .val import FastSAMValidator
|
| 7 |
+
|
| 8 |
+
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator"
|
yolov8_model/ultralytics/models/fastsam/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (390 Bytes). View file
|
|
|