"""
SAM2 Interaction Tools

Handles SAM2 mask generation from user clicks.
"""

import sys

# NOTE(review): machine-specific absolute path. It is appended before the
# package imports below, presumably so that they can resolve the external
# `sam2` checkout -- confirm, and consider making the location configurable.
sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")

import numpy as np
from PIL import Image

from .base_segmenter import BaseSegmenter
from .painter import mask_painter, point_painter

# Default rendering parameters for the mask overlay and the click markers.
# The values are handed unchanged to the painter helpers.

# Mask overlay. `first_frame_click` takes its own `mask_color` argument (same
# default), so this module-level value is not read by that method.
mask_color = 3
mask_alpha = 0.7

# Contour settings, shared by the mask overlay and the click markers.
contour_color = 1
contour_width = 5

# Click markers. NOTE(review): `first_frame_click` draws positive clicks
# (label > 0) with `point_color_ne` and negative clicks (label < 1) with
# `point_color_ps`; the suffixes read as if swapped -- confirm intent.
point_color_ne = 8
point_color_ps = 50
point_alpha = 0.9
point_radius = 15

class SamControler:
    """Click-driven mask generation on top of ``BaseSegmenter``.

    Point prompts are forwarded to the wrapped segmenter, the candidate with
    the highest score is kept, and the result is rendered as an overlay.
    """

    def __init__(self, SAM_checkpoint, model_type, device):
        """
        Initialize SAM controller

        Args:
            SAM_checkpoint: Path to SAM2 checkpoint
            model_type: SAM2 model config file
            device: Device to run on
        """
        self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
        self.device = device

    def _predict_best(self, prompts, mode, multimask):
        """Run one prediction and return the top-scoring ``(mask, logit)`` pair."""
        masks, scores, logits = self.sam_controler.predict(prompts, mode, multimask)
        best = np.argmax(scores)
        return masks[best], logits[best, :, :]

    def first_frame_click(self, image: np.ndarray, points: np.ndarray,
                          labels: np.ndarray, multimask=True, mask_color=3):
        """
        Generate mask from clicks on first frame

        ``image`` is only used for painting; it is never passed to the
        segmenter here, so this assumes the segmenter has already been given
        the image -- TODO confirm against the caller.

        Args:
            image: np.ndarray, (H, W, 3), RGB image
            points: (N, 2) array-like of [x, y] click coordinates
            labels: (N,) array-like, 1 for positive, 0 for negative
            multimask: bool, forwarded to the segmenter's ``predict``
            mask_color: int, color ID for mask overlay

        Returns:
            mask: the highest-scoring mask returned by the segmenter
            logit: the 2-D logits that belong to that mask
            painted_image: PIL.Image, visualization with mask and points

        Raises:
            ValueError: if no click is supplied.
        """
        # Accept plain sequences as well as arrays; the comparisons below
        # (`labels > 0`, `labels < 1`) need ndarray semantics.
        points = np.asarray(points)
        labels = np.asarray(labels)
        if labels.size == 0:
            raise ValueError("first_frame_click requires at least one click")

        # First pass: point prompts only.
        prompts = {
            'point_coords': points,
            'point_labels': labels,
        }
        mask, logit = self._predict_best(prompts, 'point', multimask)

        # When the most recent click is positive, run a second pass that
        # feeds the first-pass logits back in as a mask prompt.
        if labels[-1] == 1:
            prompts = {
                'point_coords': points,
                'point_labels': labels,
                'mask_input': logit[None, :, :],
            }
            mask, logit = self._predict_best(prompts, 'both', multimask)

        painted_image = mask_painter(
            image,
            mask.astype('uint8'),
            mask_color,
            mask_alpha,
            contour_color,
            contour_width,
        )

        # Positive clicks are drawn first, then negative ones.
        # NOTE(review): positive clicks use `point_color_ne` and negative
        # clicks use `point_color_ps`; the suffixes read as if swapped, but
        # the pairing is kept so the rendered output does not change.
        marker_groups = (
            (labels > 0, point_color_ne),
            (labels < 1, point_color_ps),
        )
        for selector, color in marker_groups:
            selected = points[selector]
            if len(selected) > 0:
                painted_image = point_painter(
                    painted_image,
                    selected,
                    color,
                    point_alpha,
                    point_radius,
                    contour_color,
                    contour_width,
                )

        painted_image = Image.fromarray(painted_image)

        return mask, logit, painted_image