| import os, shutil |
| import numpy as np |
| from PIL import Image |
| from typing import Literal, Any, Union, Generic, List |
| from pydantic import BaseModel |
| from sam2.build_sam import build_sam2, build_sam2_video_predictor |
| from sam2.sam2_image_predictor import SAM2ImagePredictor |
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator |
| from sam2.utils.misc import variant_to_config_mapping |
| from sam2.utils.visualization import show_masks |
| from ffmpeg_extractor import extract_frames, logger |
| from visualizer import mask_to_xyxy |
| from toolbox.vid_utils import VidInfo, VidReader |
| from toolbox.mask_encoding import b64_mask_encode |
|
|
| |
|
|
# Relative checkpoint path for every supported SAM2 variant, keyed by the same
# variant names that are used to look up the model config.
variant_checkpoints_mapping = {
    variant: f"checkpoints/sam2_hiera_{variant}.pt"
    for variant in ("tiny", "small", "base_plus", "large")
}
|
|
|
|
class bbox_xyxy(BaseModel):
    # Box prompt described by four numeric fields. run_sam_im_inference packs
    # them in the order [x0, y0, x1, y1] before handing them to model.predict.
    # Presumably (x0, y0) / (x1, y1) are the top-left / bottom-right corners in
    # pixel coordinates -- TODO confirm against callers.
    # NOTE(review): with pydantic v1 a Union[int, float] field tries `int`
    # first and may truncate float input -- confirm the pydantic version in use.
    x0: Union[int, float]
    y0: Union[int, float]
    x1: Union[int, float]
    y1: Union[int, float]
|
|
|
|
class point_xy(BaseModel):
    # Point prompt with two numeric fields. run_sam_im_inference packs them as
    # [x, y] before handing them to model.predict.
    # Presumably pixel coordinates in the input image -- TODO confirm.
    # NOTE(review): with pydantic v1 a Union[int, float] field tries `int`
    # first and may truncate float input -- confirm the pydantic version in use.
    x: Union[int, float]
    y: Union[int, float]
|
|
|
|
def load_sam_image_model(
    variant: Literal["tiny", "small", "base_plus", "large"],
    device: str = "cpu",
    auto_mask_gen: bool = False,
) -> Union[SAM2ImagePredictor, SAM2AutomaticMaskGenerator]:
    """Build a SAM2 image model and wrap it for inference.

    Args:
        variant: Model size; used as the key into both
            ``variant_to_config_mapping`` and ``variant_checkpoints_mapping``.
        device: Device string forwarded unchanged to ``build_sam2``.
        auto_mask_gen: When true the model is wrapped in a
            ``SAM2AutomaticMaskGenerator``; otherwise in a
            ``SAM2ImagePredictor``.

    Returns:
        The wrapper selected by ``auto_mask_gen``. Note that the two wrappers
        are different classes, hence the Union return type.

    Raises:
        KeyError: If ``variant`` is not a key of the checkpoint mapping.
    """
    model = build_sam2(
        config_file=variant_to_config_mapping[variant],
        ckpt_path=variant_checkpoints_mapping[variant],
        device=device,
    )
    if auto_mask_gen:
        return SAM2AutomaticMaskGenerator(model)
    return SAM2ImagePredictor(sam_model=model)
|
|
|
|
def load_sam_video_model(
    variant: Literal["tiny", "small", "base_plus", "large"] = "small",
    device: str = "cpu",
) -> Any:
    """Build a SAM2 video predictor for ``variant`` on ``device``.

    The config and checkpoint are both looked up by ``variant`` and passed,
    together with ``device``, to ``build_sam2_video_predictor``; whatever that
    call returns is handed back unchanged.
    """
    config_file = variant_to_config_mapping[variant]
    checkpoint_path = variant_checkpoints_mapping[variant]
    predictor = build_sam2_video_predictor(
        config_file=config_file,
        ckpt_path=checkpoint_path,
        device=device,
    )
    return predictor
|
|
|
|
def run_sam_im_inference(
    model: Any,
    image: Image.Image,
    points: Union[List[point_xy], List[dict], None] = None,
    point_labels: Union[List[int], None] = None,
    bboxes: Union[List[bbox_xyxy], List[dict], None] = None,
    get_pil_mask: bool = False,
    b64_encode_mask: bool = False,
):
    """Run prompted SAM2 segmentation on a single image.

    Args:
        model: Object exposing ``set_image(ndarray)`` and ``predict(...)``.
        image: Input image; it is converted to RGB before being given to the
            model.
        points: Point prompts, as ``point_xy`` instances or dicts with the
            same fields. ``None`` or empty means no point prompts.
        point_labels: One label per entry of ``points``.
        bboxes: Box prompts, as ``bbox_xyxy`` instances or dicts with the same
            fields. ``None`` or empty means no box prompts.
        get_pil_mask: When true, return the result of ``show_masks`` for the
            predicted masks instead of the raw masks; ``b64_encode_mask`` is
            ignored in that case.
        b64_encode_mask: When true (and ``get_pil_mask`` is false), return
            each mask as an ASCII string produced by ``b64_mask_encode``.

    Returns:
        By default a list of ``uint8`` numpy masks, one per mask returned by
        ``model.predict``, with singleton dimensions squeezed out (expected to
        be ``(h, w)`` each). See ``get_pil_mask`` / ``b64_encode_mask`` for the
        alternative return forms.

    Raises:
        AssertionError: If neither points nor boxes are given, or if the
            number of points and point labels differ.
    """
    # None defaults avoid sharing one mutable list object across calls.
    points = list(points) if points else []
    point_labels = list(point_labels) if point_labels else []
    bboxes = list(bboxes) if bboxes else []

    # Raised explicitly instead of via `assert` so the checks still run under
    # `python -O`, while callers keep seeing the same exception type.
    if not (points or bboxes):
        raise AssertionError(
            "SAM2 Image Inference must have either bounding boxes or points. Neither were provided."
        )
    if points and len(points) != len(point_labels):
        raise AssertionError(
            f"{len(points)} points provided but {len(point_labels)} labels given."
        )

    # Accept plain dicts by promoting them to the pydantic prompt models.
    bboxes = [bbox_xyxy(**b) if isinstance(b, dict) else b for b in bboxes]
    points = [point_xy(**p) if isinstance(p, dict) else p for p in points]

    image_rgb = np.array(image.convert("RGB"))
    model.set_image(image_rgb)

    # Absent prompt kinds are passed to the model as None, not empty arrays.
    box_coords = (
        np.array([[b.x0, b.y0, b.x1, b.y1] for b in bboxes]) if bboxes else None
    )
    point_coords = np.array([[p.x, p.y] for p in points]) if points else None
    label_array = np.array(point_labels) if point_labels else None

    masks, _scores, _ = model.predict(
        box=box_coords,
        point_coords=point_coords,
        point_labels=label_array,
        multimask_output=False,
    )

    if get_pil_mask:
        return show_masks(image_rgb, masks, scores=None, display_image=False)

    # Each mask may carry a leading singleton dimension; squeeze() drops it and
    # leaves already-2D masks untouched.
    output_masks = [mask.squeeze().astype(np.uint8) for mask in masks]
    if b64_encode_mask:
        return [b64_mask_encode(m).decode("ascii") for m in output_masks]
    return output_masks
|
|
|
|
def unpack_masks(
    masks_generator,
    frame_wh: tuple,
    drop_mask: bool = False,
):
    """Flatten a SAM2 mask generator into a list of detections in Miro's format.

    ``masks_generator`` must yield ``(frame_idx, tracker_ids, mask_logits)``
    triples. Logits are thresholded at 0 to obtain binary masks, and each mask
    is paired with its tracker id. Masks for which ``mask_to_xyxy`` returns a
    falsy value are skipped. Every remaining mask becomes a dict holding the
    frame index, the track id, the box position and size divided by the frame
    width/height from ``frame_wh``, a constant ``conf`` of 1 and, unless
    ``drop_mask`` is set, the mask encoded via ``b64_mask_encode``.
    """
    frame_w, frame_h = frame_wh
    results = []
    for frame_idx, tracker_ids, mask_logits in masks_generator:
        binary_masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
        for track_id, raw_mask in zip(tracker_ids, binary_masks):
            flat_mask = raw_mask.squeeze().astype(np.uint8)
            box = mask_to_xyxy(flat_mask)
            if box:
                left, top, right, bottom = box
                entry = {
                    "frame": frame_idx,
                    "track_id": track_id,
                    "x": left / frame_w,
                    "y": top / frame_h,
                    "w": (right - left) / frame_w,
                    "h": (bottom - top) / frame_h,
                    "conf": 1,
                }
                if not drop_mask:
                    entry["mask_b64"] = b64_mask_encode(flat_mask).decode("ascii")
                results.append(entry)
    return results
|
|
|
|
def run_sam_video_inference(
    model: Any,
    video_path: str,
    masks: np.ndarray,
    device: str = "cpu",
    sample_fps: Union[int, None] = None,
    every_x: Union[int, None] = None,
    do_tidy_up: bool = False,
    drop_mask: bool = True,
    async_frame_load: bool = False,
    ref_frame_idx: int = 0,
):
    """Track the objects described by ``masks`` through a video with SAM2.

    Frames are first extracted to disk with ``extract_frames``; the directory
    holding them is given to ``model.init_state``. Each entry of ``masks`` is
    registered on frame ``ref_frame_idx`` with its position in ``masks`` as
    the object id, then the model is propagated through the video. When
    ``ref_frame_idx`` is not 0 a second, reverse propagation is run as well.

    Args:
        model: Video predictor exposing ``init_state``, ``add_new_mask`` and
            ``propagate_in_video``.
        video_path: Path of the video to process.
        masks: Non-empty sequence of masks, one per object to track.
        device: Device string forwarded to ``model.init_state``.
        sample_fps: Forwarded to ``extract_frames`` as ``fps``.
        every_x: Forwarded to ``extract_frames`` as ``every_x``.
        do_tidy_up: When true the extracted-frames directory is deleted before
            returning, including when inference fails.
        drop_mask: Forwarded to ``unpack_masks``; when true the detections do
            not carry the encoded mask.
        async_frame_load: Forwarded to ``model.init_state`` as
            ``async_loading_frames``.
        ref_frame_idx: Index of the frame the masks belong to.

    Returns:
        List of detection dicts as produced by ``unpack_masks``: forward-pass
        detections first, followed by reverse-pass detections (if any).

    Raises:
        ValueError: If ``masks`` is empty or no frames could be extracted.
    """
    # Fail fast: without masks there is nothing to track, and the expensive
    # frame extraction below would be wasted.
    if len(masks) == 0:
        raise ValueError("SAM2 video inference requires at least one mask.")

    l_frames_fp = extract_frames(
        video_path,
        fps=sample_fps,
        every_x=every_x,
        overwrite=True,
        im_name_pattern="%05d.jpg",
    )
    if len(l_frames_fp) == 0:
        raise ValueError(f"no frames could be extracted from {video_path}")
    vframes_dir = os.path.dirname(l_frames_fp[0])

    try:
        vinfo = VidInfo(video_path)
        # NOTE(review): boxes are normalised with the source video's size;
        # this assumes extract_frames keeps the original resolution -- confirm.
        frame_wh = (vinfo["frame_width"], vinfo["frame_height"])

        inference_state = model.init_state(
            video_path=vframes_dir,
            device=device,
            async_loading_frames=async_frame_load,
        )
        for mask_idx, mask in enumerate(masks):
            _, object_ids, mask_logits = model.add_new_mask(
                inference_state=inference_state,
                frame_idx=ref_frame_idx,
                obj_id=mask_idx,
                mask=mask,
            )
            logger.debug(
                f"adding mask {mask_idx} of shape {mask.shape} for frame {ref_frame_idx}, xyxy: {mask_to_xyxy(mask)}"
            )

        logger.debug(f"model initiated with mask_logits of shape {mask_logits.shape}")
        logger.debug(f"model initiated with object_ids of len {len(object_ids)}")

        # NOTE(review): the "frame" values index the extracted frame sequence;
        # with sample_fps / every_x they presumably differ from source-video
        # frame numbers -- confirm with callers.
        detections = unpack_masks(
            model.propagate_in_video(inference_state),
            drop_mask=drop_mask,
            frame_wh=frame_wh,
        )

        if ref_frame_idx != 0:
            logger.debug(f"propagating in reverse now from {ref_frame_idx}")
            reverse_detections = unpack_masks(
                model.propagate_in_video(inference_state, reverse=True),
                drop_mask=drop_mask,
                frame_wh=frame_wh,
            )
            # The reference frame is already covered by the forward pass;
            # drop it here so it is not reported twice.
            detections += [
                det for det in reverse_detections if det["frame"] != ref_frame_idx
            ]
    finally:
        # Clean up in `finally` so a failed run does not leave frames on disk.
        if do_tidy_up:
            shutil.rmtree(vframes_dir)

    return detections
|
|