| import cv2 |
| import torch |
|
|
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator |
| from sam2.build_sam import build_sam2 |
| from sam2.build_sam import build_sam2_video_predictor |
| import sam2 |
| from PIL import Image |
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
|
|
| import argparse |
|
|
def area(mask):
    """Return the fraction of non-zero elements in ``mask``.

    An empty array (``mask.size == 0``) yields ``0`` instead of dividing by zero.
    """
    total = mask.size
    if not total:
        return 0
    return np.count_nonzero(mask) / total
|
|
def show_mask(mask, ax, obj_id=None, random_color=False, borders=True, alpha=0.5):
    """Overlay a single binary mask on a matplotlib axis.

    Args:
        mask: array whose last two dimensions are (H, W). Leading singleton
            dimensions (e.g. a (1, H, W) mask) are collapsed so the mask is
            handled as a 2-D image.
        ax: object providing ``imshow``; only ``ax.imshow`` is called.
        obj_id: when given (and ``random_color`` is False), the fill colour is
            taken from the ``tab10`` colormap at this index.
        random_color: use a random RGB fill colour instead.
        borders: also draw simplified outer contours in white.
        alpha: opacity of the mask fill.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
    elif obj_id is not None:
        color = np.array([*plt.get_cmap("tab10")(obj_id)[:3], alpha])
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, alpha])

    h, w = mask.shape[-2:]
    # Collapse any leading singleton dims: cv2.findContours needs a 2-D image,
    # which a (1, H, W) mask would otherwise not be.
    mask = mask.reshape(h, w).astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Light polygon simplification to smooth the drawn outline.
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)
|
|
def area(mask):
    """Fraction of ``mask`` elements that are non-zero (0 for an empty array).

    NOTE(review): this module defines ``area`` twice with identical bodies;
    the later definition is the one bound at module level. One copy can be
    removed.
    """
    return np.count_nonzero(mask) / mask.size if mask.size else 0
|
|
def nms_bbox_removal(boxes_xyxy, iou_thresh=0.25):
    """Remove boxes that overlap another box too much.

    For every pair (i, j) with i < j, the overlap is measured in both
    directions with ``compute_iou``. That function divides the intersection by
    the area of its FIRST argument, so the two values differ. If either value
    exceeds ``iou_thresh``:
      - the larger box is dropped (the smaller one is kept);
      - on a tie, box i is dropped.
    A box that has already been marked for removal can still cause other
    boxes to be removed.

    Args:
        boxes_xyxy: sequence of (x1, y1, x2, y2) boxes.
        iou_thresh: overlap ratio above which a pair counts as duplicate.

    Returns:
        list of the surviving boxes, in their original order.
    """
    removed = set()  # set gives O(1) membership in the final filter
    n = len(boxes_xyxy)
    for i, box in enumerate(boxes_xyxy):
        for j in range(i + 1, n):
            box2 = boxes_xyxy[j]
            overlap_ij = compute_iou(box, box2)  # intersection / area(box i)
            overlap_ji = compute_iou(box2, box)  # intersection / area(box j)
            if overlap_ij > iou_thresh or overlap_ji > iou_thresh:
                # overlap_ij > overlap_ji  <=>  box i is smaller -> drop j.
                removed.add(j if overlap_ij > overlap_ji else i)
    return [box for i, box in enumerate(boxes_xyxy) if i not in removed]
|
|
def load_SAM2(ckpt_path, model_cfg_path):
    """Build a SAM2 model from a config and checkpoint.

    Uses CUDA when ``torch.cuda.is_available()``, otherwise CPU, and prints
    which device was chosen. ``apply_postprocessing=False`` is passed to
    ``build_sam2``.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        print("Using CUDA")
    else:
        print("CUDA device not found, using CPU instead")
    # Local name chosen so it does not shadow the imported ``sam2`` module.
    model = build_sam2(
        model_cfg_path,
        ckpt_path,
        device="cuda" if use_cuda else "cpu",
        apply_postprocessing=False,
    )
    return model
|
|
def compute_iou(box1, box2):
    """Return the overlap of two (x1, y1, x2, y2) boxes relative to ``box1``.

    Despite the name, this is NOT intersection-over-union. The denominator is
    the area of ``box1`` only, so the result is asymmetric:
    ``compute_iou(a, b) != compute_iou(b, a)`` in general.
    ``nms_bbox_removal`` relies on that asymmetry.

    Returns 0 when the boxes do not overlap with positive width and height.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    left, top = max(ax1, bx1), max(ay1, by1)
    right, bottom = min(ax2, bx2), min(ay2, by2)
    if right <= left or bottom <= top:
        return 0
    return (right - left) * (bottom - top) / ((ax2 - ax1) * (ay2 - ay1))
|
|
def show_anns(anns, color=None, borders=True):
    """Paint SAM-style annotations onto the current matplotlib axis.

    Args:
        anns: list of dicts with a 'segmentation' mask (squeezed before use
            and used as a boolean index) and an 'area' key. Annotations are
            painted in decreasing 'area' order, so smaller ones end up on top.
        color: RGBA value used for every mask. ``None`` picks a random RGB
            colour with alpha 0.75 for each annotation.
        borders: draw simplified outer contours (blue, alpha 0.4) for each mask.

    Returns nothing; an empty ``anns`` returns before touching matplotlib.
    """
    if not len(anns):
        return
    ordered = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axis = plt.gca()
    axis.set_autoscale_on(False)

    seg_shape = ordered[0]['segmentation'].squeeze().shape
    canvas = np.ones((seg_shape[0], seg_shape[1], 4))
    canvas[:, :, 3] = 0  # start fully transparent
    for ann in ordered:
        seg = ann['segmentation'].squeeze()
        fill = np.concatenate([np.random.random(3), [0.75]]) if color is None else color
        canvas[seg] = fill
        if borders:
            contours, _ = cv2.findContours(seg.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            contours = [cv2.approxPolyDP(c, epsilon=0.01, closed=True) for c in contours]
            cv2.drawContours(canvas, contours, -1, (0, 0, 1, 0.4), thickness=2)

    axis.imshow(canvas)
|
|
def build_sam2_predictor(checkpoint="checkpoints/sam2_hiera_large.pt", model_cfg="sam2_hiera_l"):
    """Build a SAM2 video predictor on CUDA when available, otherwise CPU.

    ``apply_postprocessing=False`` is forwarded to ``build_sam2_video_predictor``.
    """
    return build_sam2_video_predictor(
        model_cfg,
        checkpoint,
        device="cuda" if torch.cuda.is_available() else "cpu",
        apply_postprocessing=False,
    )
|
|
def load_masks(video_predictor, query_images, support_image, support_masks, offload_video_to_cpu=True, offload_state_to_cpu=True, verbose=False):
    """Initialise a SAM2 video-predictor state and prompt it with support masks.

    The support image is used as frame 0 of the image sequence. Each support
    mask is added on that frame as its own object; the object id is the mask's
    index in ``support_masks``.

    Args:
        video_predictor: SAM2 video predictor (e.g. from ``build_sam2_predictor``).
        query_images: list of np.array of shape (H, W, 3).
            NOTE: modified in place — ``support_image`` is inserted at index 0.
        support_image: np.array of shape (H, W, 3).
        support_masks: list of np.array of shape (H, W). Each mask is cast to
            uint8 and resized to 1024x1024 before being added.
        offload_video_to_cpu: for long video sequences, offload the video to
            the CPU to save GPU memory.
        offload_state_to_cpu: save GPU memory by offloading the state to the CPU.
        verbose: forwarded to ``init_state``.

    Returns:
        The initialised inference state.
    """
    # In-place insert preserved on purpose: callers may index ``query_images``
    # expecting the support frame at position 0.
    query_images.insert(0, support_image)
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = video_predictor.init_state(
            None,
            image_inputs=query_images,
            async_loading_frames=False,
            offload_video_to_cpu=offload_video_to_cpu,
            offload_state_to_cpu=offload_state_to_cpu,
            verbose=verbose,
        )
        video_predictor.reset_state(state)
        support_frame_idx = 0
        for obj_id, support_mask in enumerate(support_masks):
            resized_mask = cv2.resize(np.array(support_mask, dtype=np.uint8), (1024, 1024))
            _frame, _ids, _masks = video_predictor.add_new_mask(
                inference_state=state,
                frame_idx=support_frame_idx,
                obj_id=obj_id,
                mask=resized_mask,
            )
    return state
|
|
def propagate_masks(video_predictor, state, verbose=False):
    """Propagate the prompted masks through every frame of ``state``.

    Returns:
        list[dict], one entry per item yielded by ``propagate_in_video``:
          'obj_ids': the object ids the predictor yielded for that frame.
          'segmentation': boolean np.array built as follows:
              - threshold the mask logits at 0;
              - squeeze the result;
              - add a leading axis if it became 2-D.
              For (N, 1, H, W) logits this gives shape (N, H, W).
              NOTE(review): that logits layout is assumed, not checked here.
          'area': ``area`` of that whole array, i.e. the fraction of True
              elements across all objects together.
    """
    frames = []
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        for _frame_idx, obj_ids, mask_logits in video_predictor.propagate_in_video(state, verbose=verbose):
            masks = (mask_logits > 0).cpu().numpy().squeeze()
            if masks.ndim == 2:
                masks = masks[np.newaxis]  # keep a leading object axis for a single object
            frames.append({'obj_ids': obj_ids, 'segmentation': masks, 'area': area(masks)})
    return frames
|
|
def show_video_masks(image, frame_info):
    """Display ``image`` with one frame's masks overlaid, then call plt.show().

    The image and every mask are resized to 1024x1024 so they line up.

    Args:
        image: image array accepted by ``cv2.resize`` and ``plt.imshow``.
        frame_info: one entry of the list returned by ``propagate_masks``,
            i.e. a dict with 'obj_ids' and 'segmentation' keys. A dict that
            stores the masks under a 'masks' key is also accepted.
    """
    # propagate_masks stores masks under 'segmentation'; reading only 'masks'
    # raised KeyError for its output. Fall back to 'masks' for other callers.
    masks = frame_info['segmentation'] if 'segmentation' in frame_info else frame_info['masks']
    plt.imshow(cv2.resize(image, (1024, 1024)))
    for obj_id, mask in zip(frame_info['obj_ids'], masks):
        resized_mask = cv2.resize(mask.astype(np.uint8), (1024, 1024))
        show_mask(resized_mask, plt.gca(), obj_id=obj_id, borders=True, alpha=0.75)
    plt.axis('off')
    plt.show()
|
|
def get_parser(inputs):
    """Parse the CLI tokens in ``inputs`` and return the resulting Namespace.

    Despite the name, this returns parsed arguments, not the parser. It
    recognises:
      - ``--config-file FILE``;
      - ``--opts ...``, which collects every token after it verbatim.
    """
    cli = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    cli.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    cli.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return cli.parse_args(inputs)
|
|
def auto_segment_SAM(boxes_xyxy, img, iou_thresh=0.9, stability_score_thresh=0.95, min_mask_region_area=10000, verbose=False,
                     checkpoint="../../checkpoints/sam2_hiera_large.pt",
                     model_cfg="../../sam2_configs/sam2_hiera_l.yaml"):
    """Segment the inside of each box with SAM2's automatic mask generator.

    A SAM2 model is loaded with ``load_SAM2`` on every call. Each box is then
    cropped out of ``img`` and segmented independently, and every resulting
    mask is pasted back into an image-sized boolean canvas.

    Args:
        boxes_xyxy: iterable of boxes. Elements 0-3 are read as
            (x1, y1, x2, y2) and truncated with ``int``. Coordinates are
            clamped to the image, so negative values cannot wrap around
            through negative indexing.
        img: image indexed as img[row, col]; its first two dims are (H, W).
        iou_thresh: forwarded as ``pred_iou_thresh``.
        stability_score_thresh: forwarded as ``stability_score_thresh``.
        min_mask_region_area: forwarded as ``min_mask_region_area``.
        verbose: display each crop with its generated masks.
        checkpoint: SAM2 checkpoint path (relative to the working directory).
        model_cfg: SAM2 model config path (relative to the working directory).

    Returns:
        list of dicts, in box order, each with:
          'segmentation': bool array of shape img.shape[:2];
          'area': fraction of the full image that the mask covers.
        Boxes whose clamped crop is empty are skipped.
    """
    sam2_model = load_SAM2(checkpoint, model_cfg)
    auto_mask_predictor = SAM2AutomaticMaskGenerator(
        sam2_model,
        points_per_batch=128,
        pred_iou_thresh=iou_thresh,
        stability_score_thresh=stability_score_thresh,
        min_mask_region_area=min_mask_region_area,
        multimask_output=True,
    )
    img_h, img_w = img.shape[0], img.shape[1]
    masks_list = []
    for box_xyxy in boxes_xyxy:
        rows, cols = _box_slices(box_xyxy, img_h, img_w)
        crop = img[rows, cols]
        if crop.size == 0:
            # Degenerate or fully out-of-image box: nothing to segment.
            continue
        anns = auto_mask_predictor.generate(crop)

        if verbose:
            plt.imshow(crop)
            show_anns(anns)
            plt.axis('off')
            plt.show()

        for ann in anns:
            # Paste the crop-sized mask back at the box position.
            full_mask = np.zeros((img_h, img_w), dtype=bool)
            full_mask[rows, cols] = ann['segmentation']
            masks_list.append({'segmentation': full_mask, 'area': area(full_mask)})
    return masks_list


def _box_slices(box_xyxy, img_h, img_w):
    """Return (row_slice, col_slice) for an xyxy box, int-truncated and clamped to [0, H] x [0, W]."""
    x1 = min(max(int(box_xyxy[0]), 0), img_w)
    y1 = min(max(int(box_xyxy[1]), 0), img_h)
    x2 = min(max(int(box_xyxy[2]), 0), img_w)
    y2 = min(max(int(box_xyxy[3]), 0), img_h)
    return slice(y1, y2), slice(x1, x2)
|
|
def show_masks(masks_list, img, verbose=True, imshow=True, grey=False):
    """Draw ``masks_list`` over ``img`` using ``show_anns``.

    Args:
        masks_list: annotations in the format ``show_anns`` expects.
        img: background image.
        verbose: call ``plt.show()`` at the end.
        imshow: draw ``img`` first and hide the axes. When False, only the
            masks are drawn onto the current axis.
        grey: with ``imshow``, convert ``img`` using ``cv2.COLOR_BGR2GRAY``
            and display it with a gray colormap.
    """
    if imshow:
        if not grey:
            plt.imshow(img)
        else:
            plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), cmap='gray')
        plt.axis('off')
    show_anns(masks_list)
    if verbose:
        plt.show()
|
|
def show_individual_masks(masks_list, img):
    """Show ``img`` with each mask of ``masks_list`` drawn on its own.

    Masks are drawn one after another, with ``plt.show()`` called after each.
    """
    for single_mask in masks_list:
        plt.imshow(img)
        plt.axis('off')
        show_anns([single_mask])
        plt.show()