import cv2
import numpy as np
import supervision as sv
|
|
import torch
import torchvision
|
|
from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
|
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
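# GroundingDINO config and checkpoint paths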
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
|
|
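# Segment Anything (SAM) encoder variant and checkpoint path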
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"
|
|
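# Build the GroundingDINO inference model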
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
|
|
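# Build the SAM model and predictor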
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)
|
|
|
|
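# Source image, text-prompt classes, and detection hyperparameters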
SOURCE_IMAGE_PATH = "./assets/demo2.jpg"
CLASSES = ["The running dog"]
BOX_THRESHOLD = 0.25   # confidence threshold for GroundingDINO boxes
TEXT_THRESHOLD = 0.25  # confidence threshold for text-phrase matching
NMS_THRESHOLD = 0.8    # IoU threshold for non-maximum suppression
|
|
|
|
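# Load the source image (OpenCV reads it in BGR order)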
image = cv2.imread(SOURCE_IMAGE_PATH)
|
|
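# Detect objects with GroundingDINO, using the class names as the text prompt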
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=CLASSES,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD
)
|
|
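# Annotate the image with the GroundingDINO detections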
box_annotator = sv.BoxAnnotator()
labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _
    in detections]
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
|
|
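# Save the GroundingDINO-annotated image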
| cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame) |
|
|
|
|
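# NMS post-processing: drop heavily overlapping boxes (class-agnostic)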
| print(f"Before NMS: {len(detections.xyxy)} boxes") |
| nms_idx = torchvision.ops.nms( |
| torch.from_numpy(detections.xyxy), |
| torch.from_numpy(detections.confidence), |
| NMS_THRESHOLD |
| ).numpy().tolist() |
|
|
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]
|
|
| print(f"After NMS: {len(detections.xyxy)} boxes") |
|
|
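# Prompt SAM with each detected box and keep the best candidate mask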
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        # multimask_output=True returns several candidate masks per box;
        # keep the one with the highest predicted quality score
        masks, scores, logits = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)
|
|
|
|
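# Convert detections to masks (SAM expects RGB input, so convert from BGR)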
detections.mask = segment(
    sam_predictor=sam_predictor,
    image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
    xyxy=detections.xyxy
)
|
|
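# Annotate the image with boxes, labels, and segmentation masks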
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()
labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _
    in detections]
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
|
|
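# Save the final Grounded-SAM annotated image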
| cv2.imwrite("grounded_sam_annotated_image.jpg", annotated_image) |
|
|