| from .tracker.byte_tracker import BYTETracker |
| import cv2 |
| import numpy as np |
|
|
class ByteTrack(object):
    """Run a detector on a frame and associate its detections into tracks with BYTETracker.

    The wrapped ``detector`` must expose ``model.get_inputs()`` (the first
    input's shape is used to derive ``input_shape``) and a
    ``detect(image, conf_thres=..., input_shape=..., classes=...)`` method that
    returns ``(dets, image_info)``, where ``image_info`` has ``'height'`` and
    ``'width'`` keys.
    """

    def __init__(self, detector, min_box_area=10, frame_rate=30,
                 aspect_ratio_thresh=1.6):
        """Create the tracking wrapper.

        Args:
            detector: Detector object (see the class docstring for what is used).
            min_box_area: Tracks whose box area (w * h) is not strictly greater
                than this are dropped from the output.
            frame_rate: Frame rate passed to ``BYTETracker``.
            aspect_ratio_thresh: Tracks whose width / height ratio exceeds this
                are dropped from the output. The default keeps the previous
                hard-coded behaviour.
        """
        self.min_box_area = min_box_area
        self.aspect_ratio_thresh = aspect_ratio_thresh

        # Standard ImageNet mean/std normalisation constants. No method of this
        # class reads them; presumably kept for external callers.
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

        self.detector = detector
        # Drops the first two dims of the model's first input shape; assumes an
        # NCHW layout so this yields (H, W) — TODO confirm against the detector.
        self.input_shape = tuple(detector.model.get_inputs()[0].shape[2:])
        self.tracker = BYTETracker(frame_rate=frame_rate)

    def inference(self, image, conf_thresh=0.25, classes=None):
        """Detect objects in ``image`` and update the tracker with them.

        Args:
            image: Frame passed unchanged to ``detector.detect``.
            conf_thresh: Confidence threshold, forwarded as ``conf_thres``.
            classes: Class filter forwarded to the detector.

        Returns:
            ``(bboxes, ids, scores, class_ids)``: four lists of equal length, one
            entry per reported track. ``bboxes`` holds each track's ``tlwh``.
            ``class_ids[i]`` is the last column of the detection row that
            overlaps ``bboxes[i]`` most, or ``None`` if no detection overlaps
            it. All four lists are empty when the detector returns no detections.
        """
        dets, image_info = self.detector.detect(
            image,
            conf_thres=conf_thresh,
            input_shape=self.input_shape,
            classes=classes,
        )

        bboxes, ids, scores, class_ids = [], [], [], []

        # NOTE(review): with no detections the tracker is not updated for this
        # frame at all. Whether tracks should age on empty frames depends on how
        # BYTETracker.update handles empty input — confirm.
        if isinstance(dets, np.ndarray) and len(dets) > 0:
            bboxes, ids, scores = self._tracker_update(dets, image_info)
            # The results above are the tracker's output tracks (further
            # filtered in _tracker_update), not one entry per detection row, so
            # classes cannot be taken positionally from ``dets``. Match each
            # track back to a detection by box overlap instead.
            class_ids = self._assign_class_ids(dets, bboxes)

        return bboxes, ids, scores, class_ids

    @staticmethod
    def _assign_class_ids(dets, tlwhs):
        """Return, for each ``tlwh`` box, the class of its best-overlapping detection.

        Assumes ``dets[:, :4]`` are ``(x1, y1, x2, y2)`` boxes in the same pixel
        frame as the tracker output and that ``dets[:, -1]`` holds the class id
        — TODO confirm against the detector. Boxes with zero IoU against every
        detection get ``None``.
        """
        if len(tlwhs) == 0:
            return []

        det_classes = dets[:, -1].tolist()
        det_boxes = np.asarray(dets[:, :4], dtype=np.float64)
        tracks = np.asarray(tlwhs, dtype=np.float64).reshape(-1, 4)
        trk_boxes = np.concatenate(
            [tracks[:, :2], tracks[:, :2] + tracks[:, 2:]], axis=1)

        # Pairwise IoU matrix of shape (num_tracks, num_dets).
        top_left = np.maximum(trk_boxes[:, None, :2], det_boxes[None, :, :2])
        bottom_right = np.minimum(trk_boxes[:, None, 2:], det_boxes[None, :, 2:])
        inter = np.prod(np.clip(bottom_right - top_left, 0, None), axis=2)
        trk_area = np.prod(
            np.clip(trk_boxes[:, 2:] - trk_boxes[:, :2], 0, None), axis=1)
        det_area = np.prod(
            np.clip(det_boxes[:, 2:] - det_boxes[:, :2], 0, None), axis=1)
        union = trk_area[:, None] + det_area[None, :] - inter
        iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

        best = iou.argmax(axis=1)
        return [det_classes[j] if iou[i, j] > 0 else None
                for i, j in enumerate(best)]

    def get_id_color(self, index):
        """Return a deterministic 3-tuple colour (each channel in 0..254) for a track id."""
        temp_index = abs(int(index)) * 3
        color = ((37 * temp_index) % 255, (17 * temp_index) % 255,
                 (29 * temp_index) % 255)
        return color

    def draw_tracking_info(
        self,
        image,
        tlwhs,
        ids,
        scores,
        frame_id=0,
        elapsed_time=0.,
    ):
        """Draw each track's box and id label onto ``image`` and return it.

        Boxes are drawn with ``cv2.rectangle`` in a per-id colour from
        ``get_id_color``. The id text is drawn twice, first thick in black and
        then thinner in white, giving outlined text. ``scores``, ``frame_id``
        and ``elapsed_time`` are accepted but not used.

        Args:
            image: Image drawn onto by the cv2 calls and returned.
            tlwhs: Sequence of ``(top, left, width, height)`` boxes.
            ids: Track ids, parallel to ``tlwhs``.
        """
        text_scale = 1.5
        text_thickness = 2
        line_thickness = 2

        for index, tlwh in enumerate(tlwhs):
            x1, y1 = int(tlwh[0]), int(tlwh[1])
            x2, y2 = x1 + int(tlwh[2]), y1 + int(tlwh[3])
            color = self.get_id_color(ids[index])
            cv2.rectangle(image, (x1, y1), (x2, y2), color, line_thickness)

            text = str(ids[index])
            # Black, thicker pass first so the white pass on top reads as outlined.
            cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
                        text_scale, (0, 0, 0), text_thickness + 3)
            cv2.putText(image, text, (x1, y1 - 5), cv2.FONT_HERSHEY_PLAIN,
                        text_scale, (255, 255, 255), text_thickness)
        return image

    def _tracker_update(self, dets, image_info):
        """Feed ``dets`` to the tracker and return the tracks that pass the size filters.

        The last column of ``dets`` is dropped before calling ``tracker.update``.
        The frame's ``[height, width]`` is passed as both the image info and the
        image size.

        Returns:
            ``(tlwhs, ids, scores)`` lists for the tracks whose box area exceeds
            ``min_box_area`` and whose width / height ratio does not exceed
            ``aspect_ratio_thresh``. Tracks with non-positive height are dropped.
        """
        online_targets = []
        if dets is not None:
            online_targets = self.tracker.update(
                dets[:, :-1],
                [image_info['height'], image_info['width']],
                [image_info['height'], image_info['width']],
            )

        online_tlwhs, online_ids, online_scores = [], [], []
        for target in online_targets:
            tlwh = target.tlwh
            width, height = tlwh[2], tlwh[3]
            if height <= 0:
                # Degenerate box: skip it instead of dividing by zero.
                continue
            too_wide = width / height > self.aspect_ratio_thresh
            if width * height > self.min_box_area and not too_wide:
                online_tlwhs.append(tlwh)
                online_ids.append(target.track_id)
                online_scores.append(target.score)

        return online_tlwhs, online_ids, online_scores