| |
|
|
| import cv2 |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from typing import List, Dict, Tuple, Optional |
| from ultralytics import YOLO |
|
|
def perform_detection(
    yolo_model: YOLO,
    frame: np.ndarray,
    conf_threshold: float = 0.5,
) -> Optional[list]:
    """
    Run YOLO model inference on a single frame.

    Args:
        yolo_model: A loaded ultralytics ``YOLO`` model.
        frame: Image passed straight to ``yolo_model.predict`` as ``source``.
        conf_threshold: Minimum confidence, forwarded to ``predict`` as ``conf``.

    Returns:
        Exactly what ``yolo_model.predict`` returns (a list of ultralytics
        ``Results`` objects, one per image; callers read ``results[0].boxes``),
        or ``None`` if ``frame`` is None or prediction raised.
    """
    if frame is None:
        print("Error: Input frame is None in perform_detection.")
        return None
    try:
        return yolo_model.predict(source=frame, conf=conf_threshold, verbose=False)
    except Exception as e:
        # Best-effort boundary: report and let callers treat None as "no results".
        print(f"Error during model prediction: {e}")
        return None
|
|
def create_motion_mask(frame, threshold=25, block_size=11):
    """
    Build a binary foreground mask from a single image.

    Despite the name, no motion is measured. The mask comes from adaptive
    thresholding of one frame. A pixel is marked (255) when it is at least
    ``threshold`` darker than the Gaussian-weighted mean of its
    ``block_size`` x ``block_size`` neighbourhood. Small specks are then removed
    with a morphological open, and small holes are filled with a close.

    Args:
        frame: BGR (3-channel), BGRA (4-channel) or single-channel image.
        threshold: Constant ``C`` subtracted from the local mean by
            ``cv2.adaptiveThreshold``; larger values give a sparser mask.
        block_size: Odd neighbourhood size (> 1) used for the local mean.

    Returns:
        A single-channel mask with values 0 or 255.
    """
    # cvtColor(BGR2GRAY) rejects images that are already single-channel, so
    # pick the conversion from the actual channel count.
    channels = 1 if frame.ndim == 2 else frame.shape[2]
    if channels == 1:
        gray = frame.reshape(frame.shape[:2])
    elif channels == 4:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGRA2GRAY)
    else:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Light blur so sensor noise does not dominate the local-mean threshold.
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)

    thresh = cv2.adaptiveThreshold(
        blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, block_size, threshold
    )

    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
    return mask
|
|
def _mask_accepts_box(mask, x1, y1, x2, y2, min_coverage):
    """Return True if the box centre is set in ``mask`` or more than ``min_coverage`` of the box area is set."""
    height, width = mask.shape[0], mask.shape[1]
    center_x = (x1 + x2) // 2
    center_y = (y1 + y2) // 2
    # A centre outside the mask cannot be validated, so the box is rejected.
    if not (0 <= center_y < height and 0 <= center_x < width):
        return False
    if mask[center_y, center_x] > 0:
        return True
    # Slice ends are exclusive, so clamp them to [0, dim] rather than dim - 1;
    # otherwise the last mask row/column is never counted in the coverage.
    top, bottom = max(0, min(y1, height)), max(0, min(y2, height))
    left, right = max(0, min(x1, width)), max(0, min(x2, width))
    region = mask[top:bottom, left:right]
    if region.size == 0:
        return False
    return np.count_nonzero(region > 0) / region.size > min_coverage


def postprocess_results(
    results: Optional[list],
    model_class_names: Dict[int, str],
    mask: Optional[np.ndarray] = None,
    min_mask_coverage: float = 0.5,
) -> List[Dict]:
    """
    Convert raw YOLO results into plain detection dicts, optionally mask-filtered.

    Only ``results[0].boxes`` is read, i.e. a single image per call.

    Args:
        results: Output of ``perform_detection`` (ultralytics ``Results`` list)
            or ``None``.
        model_class_names: Mapping from class id to name. Ids missing from the
            mapping become ``"Unknown Class <id>"``.
        mask: Optional mask, presumably the same height/width as the frame the
            boxes came from (box pixel coordinates index into it directly).
            When given, a box is kept only if its centre pixel is non-zero or
            more than ``min_mask_coverage`` of its area is non-zero.
        min_mask_coverage: Coverage fraction above which a box is kept even
            if its centre pixel is not set in the mask.

    Returns:
        List of dicts with ``'class_name'`` (str), ``'confidence'`` (float) and
        ``'bbox_xyxy'`` ([x1, y1, x2, y2] as Python ints). Boxes that fail to
        parse are reported and skipped.
    """
    detections_list: List[Dict] = []
    if not results:
        return detections_list

    try:
        boxes = results[0].boxes
    except (IndexError, AttributeError) as e:
        print(f"Warning: Could not access boxes in results: {e}")
        return detections_list

    for box in boxes:
        try:
            # Plain Python ints: JSON-friendly and accepted by cv2 drawing calls.
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].cpu().numpy().astype(int))
            if mask is not None and not _mask_accepts_box(
                mask, x1, y1, x2, y2, min_mask_coverage
            ):
                continue
            conf = float(box.conf[0].cpu().numpy())
            cls_id = int(box.cls[0].cpu().numpy())
            detections_list.append({
                'class_name': model_class_names.get(cls_id, f"Unknown Class {cls_id}"),
                'confidence': conf,
                'bbox_xyxy': [x1, y1, x2, y2],
            })
        except Exception as e:
            # Best-effort per box: one malformed box must not drop the rest.
            print(f"Error processing a detection box: {e}")
            continue
    return detections_list
|
|
def draw_detections(
    frame: np.ndarray,
    detections: List[Dict],
    mask: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Draw bounding boxes and labels on a copy of ``frame``.

    Args:
        frame: 3-channel image to annotate (colours below are in BGR order);
            it is not modified.
        detections: Dicts with ``'bbox_xyxy'`` ([x1, y1, x2, y2]),
            ``'class_name'`` and ``'confidence'`` keys, as produced by
            ``postprocess_results``.
        mask: Optional mask. Non-zero pixels are tinted green with a 30% blend.
            If its height/width differ from the frame, it is resized with
            nearest-neighbour interpolation first.

    Returns:
        The annotated copy of ``frame``.
    """
    output_frame = frame.copy()
    frame_h, frame_w = output_frame.shape[:2]

    if mask is not None and mask.size > 0:
        if mask.shape[:2] != (frame_h, frame_w):
            # Boolean indexing below needs matching dimensions; masks loaded
            # from disk are not guaranteed to match the frame.
            mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)
        foreground = mask > 0
        if foreground.ndim == 3:
            # A multi-channel mask counts a pixel if any channel is set.
            foreground = foreground.any(axis=2)
        mask_overlay = np.zeros_like(output_frame)
        mask_overlay[foreground] = [0, 100, 0]
        output_frame = cv2.addWeighted(output_frame, 0.7, mask_overlay, 0.3, 0)

    color = (0, 255, 0)
    text_color = (255, 255, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    text_thickness = 3
    for detection in detections:
        try:
            # cv2 drawing calls reject float coordinates; coerce to int.
            x1, y1, x2, y2 = (int(v) for v in detection['bbox_xyxy'])
            label = f"{detection['class_name']}: {detection['confidence']:.2f}"

            cv2.rectangle(output_frame, (x1, y1), (x2, y2), color, 5)

            (label_width, label_height), baseline = cv2.getTextSize(
                label, font, font_scale, text_thickness
            )
            # Push the label down when the box touches the top edge so the
            # text stays inside the image.
            label_ymin = max(y1, label_height + 10)

            cv2.rectangle(output_frame,
                          (x1, label_ymin - label_height - 10),
                          (x1 + label_width, label_ymin - baseline),
                          color,
                          cv2.FILLED)
            cv2.putText(output_frame, label, (x1, label_ymin - 5),
                        font, font_scale, text_color, text_thickness)
        except Exception as e:
            # Best-effort: skip this detection but say why instead of hiding it.
            print(f"Warning: could not draw detection {detection!r}: {e}")
            continue
    return output_frame
|
|
def load_yolo_model(model_path):
    """
    Construct a YOLO model and return it together with its class names.

    Args:
        model_path: Value handed to the ``YOLO`` constructor.

    Returns:
        ``(model, model.names)`` on success, or ``(None, None)`` if
        construction or reading the names raised (the error is printed).
    """
    print("Loading the YOLO model...")
    try:
        yolo = YOLO(model_path)
        names = yolo.names
        print(f"Model loaded with {len(names)} classes!")
    except Exception as err:
        print(f"Error loading model: {err}")
        return None, None
    return yolo, names
|
|
def load_image(image_path):
    """
    Read an image file with ``cv2.imread`` (default colour mode).

    Returns the loaded array, or ``None`` after printing an error when
    OpenCV cannot read the file.
    """
    print(f"Opening image: {image_path}")
    img = cv2.imread(image_path)
    if img is not None:
        return img
    print(f"Error: Could not read image file '{image_path}'.")
    return None
|
|
def load_or_create_mask(image, mask_path=None):
    """
    Return a mask for ``image``, loading it from disk when possible.

    If ``mask_path`` names an existing file, it is read as grayscale. A loaded
    mask whose height/width differ from ``image`` is resized to match, using
    nearest-neighbour so no new pixel values are introduced. If no usable path
    is given, or the file cannot be decoded, a mask is generated from the image
    with ``create_motion_mask`` instead.
    """
    mask = None
    if mask_path and os.path.exists(mask_path):
        print(f"Loading mask: {mask_path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals unreadable/corrupt files by returning None.
            print(f"Warning: Could not read mask file '{mask_path}'; creating one instead.")
        elif image is not None and mask.shape[:2] != image.shape[:2]:
            # Detections are filtered by indexing the mask with frame pixel
            # coordinates, so the mask must share the image's dimensions.
            height, width = image.shape[:2]
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    if mask is None:
        print("Creating mask from image...")
        mask = create_motion_mask(image)

    return mask
|
|
def display_results(output_frame, detections, mask=None, *,
                    result_path="detection_result.jpg",
                    mask_path="motion_mask.png",
                    show=True):
    """
    Print the detections, save the mask and annotated image, then show the image.

    Args:
        output_frame: Annotated BGR image (e.g. the output of ``draw_detections``).
        detections: Dicts with ``'class_name'`` and ``'confidence'`` keys.
        mask: Optional mask; written to ``mask_path`` when given.
        result_path: File the annotated image is written to.
        mask_path: File the mask is written to.
        show: Whether to display the image with matplotlib.
    """
    print("\n--- Insects Detected ---")
    if detections:
        for i, obj in enumerate(detections, 1):
            print(f"{i}. {obj['class_name']} (confidence: {obj['confidence']:.2f})")
    else:
        print("No insects detected.")

    # cv2.imwrite reports failure via its return value, not an exception,
    # so only claim success when it returned True.
    if mask is not None:
        if cv2.imwrite(mask_path, mask):
            print(f"Motion mask saved to: {mask_path}")
        else:
            print(f"Error: Could not write motion mask to '{mask_path}'.")

    # Save before showing: plt.show() can block until the window is closed,
    # and the result should be on disk even if the process is interrupted then.
    if cv2.imwrite(result_path, output_frame):
        print(f"Result saved to: {result_path}")
    else:
        print(f"Error: Could not write result image to '{result_path}'.")

    if show:
        # matplotlib expects RGB; OpenCV images are BGR.
        output_rgb = cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB)
        plt.figure(figsize=(10, 8))
        plt.imshow(output_rgb)
        plt.title("Insect Detection Results")
        plt.axis('off')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close()
|
|
| import os |