| import torch |
| import numpy as np |
| import mmcv |
| from mmengine.visualization import Visualizer |
|
|
| from third_parts.sam2.build_sam import build_sam2 |
| from third_parts.sam2.sam2_image_predictor import SAM2ImagePredictor |
| from mmdet.structures.mask import bitmap_to_polygon |
|
|
# Input image for the point-prompt segmentation demo.
IMG_PATH = "assets/view.jpg"

# SAM2 model definition: config name plus the matching checkpoint file
# (both refer to the Hiera-Large variant, going by their file names).
MODEL_CFG = "sam2_hiera_l.yaml"
MODEL_CKPT = "work_dirs/ckpt/sam2_hiera_large.pt"
|
|
|
|
| def prepare(): |
| torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
|
|
if __name__ == '__main__':
    # Demo: segment the object under one positive click with SAM2, then draw
    # the best-scoring mask, its outline and the click point, and save it.
    prepare()
    sam2_model = build_sam2(MODEL_CFG, MODEL_CKPT, device="cuda")
    predictor = SAM2ImagePredictor(sam2_model)

    # mmcv.imread defaults to BGR, but SAM2ImagePredictor.set_image and
    # mmengine's Visualizer both work on RGB images.
    image = mmcv.imread(IMG_PATH, channel_order='rgb')
    predictor.set_image(image)

    # A single click at pixel (x=500, y=475); label 1 = foreground point.
    input_point = np.array([[500, 475]])
    input_label = np.array([1])

    # multimask_output=True yields several candidate masks with scores;
    # only the highest-scoring candidate is visualized.
    masks, scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    best = int(np.argmax(scores))
    masks = masks[best:best + 1].astype(bool)  # keep a leading mask axis

    visualizer = Visualizer(image=image)
    polygons = []
    for mask in masks:
        # Only the contour list is drawn; the second return value is unused.
        contours, _ = bitmap_to_polygon(mask)
        polygons.extend(contours)
    visualizer.draw_polygons(polygons, edge_colors='w', alpha=0.8)
    visualizer.draw_binary_masks(masks, alphas=0.8)
    visualizer.draw_points(input_point, 'r', marker='*')

    # get_image() returns an RGB array while mmcv.imwrite expects BGR.
    result = visualizer.get_image()
    out_path = 'work_dirs/sam2_demo_vis.jpg'
    mmcv.imwrite(mmcv.rgb2bgr(result), out_path)
    print(f'Saved visualization to {out_path}')
| |