| import os |
|
|
| import torch |
| import numpy as np |
| import mmcv |
| import mmengine |
| from mmengine.visualization import Visualizer |
|
|
| from third_parts.sam2.build_sam import build_sam2_video_predictor |
| from mmdet.structures.mask import bitmap_to_polygon |
|
|
# Directory of video frames. It is passed to the SAM2 predictor's init_state
# and re-listed for visualization; the JPEG files are expected to be named by
# their integer frame index.
VID_PATH: str = 'assets/vid_view'
# SAM2 checkpoint weights and the config name handed to build_sam2_video_predictor.
MODEL_CKPT: str = "work_dirs/ckpt/sam2_hiera_large.pt"
MODEL_CFG: str = "sam2_hiera_l.yaml"
|
|
|
|
def prepare():
    """Configure process-wide PyTorch numerics for SAM2 inference.

    When CUDA is available, enters a bfloat16 ``torch.autocast`` context for
    CUDA ops and deliberately never exits it. Autocast therefore stays active
    for the rest of the script rather than for one ``with`` scope. TF32 is
    always allowed for cuBLAS matmuls and cuDNN convolutions.

    Returns:
        The entered autocast context manager, so a caller can later disable
        it with ``ctx.__exit__(None, None, None)``. Returns ``None`` when CUDA
        is unavailable and autocast was not entered.
    """
    autocast_ctx = None
    # On a CUDA-less host, torch would only warn and disable CUDA autocast,
    # so skip it explicitly instead of emitting a spurious warning.
    if torch.cuda.is_available():
        autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        # Entered manually (not via `with`) so it spans the whole script.
        autocast_ctx.__enter__()
    # Global flags; per the PyTorch docs, TF32 only takes effect on GPUs that
    # support it.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    return autocast_ctx
|
|
|
|
def collect_video_segments(predictor, inference_state):
    """Propagate the registered prompts through the video and binarize masks.

    Args:
        predictor: SAM2 video predictor returned by ``build_sam2_video_predictor``.
        inference_state: state returned by ``predictor.init_state`` with prompts added.

    Returns:
        dict mapping frame index -> {object id: boolean numpy mask}. Each mask
        is ``out_mask_logits[i] > 0.0`` and keeps that tensor's leading axis;
        ``mask[0]`` is treated as the 2-D mask downstream (shape presumably
        (1, H, W) -- TODO confirm against the SAM2 predictor).
    """
    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments


def list_frame_names(vid_path):
    """Return the JPEG file names in *vid_path*, sorted numerically by stem.

    Every JPEG stem must parse as an int (e.g. ``00012.jpg``); otherwise the
    sort raises ValueError.
    """
    jpeg_exts = {".jpg", ".jpeg", ".JPG", ".JPEG"}
    names = [p for p in os.listdir(vid_path) if os.path.splitext(p)[-1] in jpeg_exts]
    names.sort(key=lambda p: int(os.path.splitext(p)[0]))
    return names


def stack_object_masks(masks):
    """Stack one frame's per-object masks into a single (N, H, W) array.

    Args:
        masks: dict mapping object id -> mask array whose ``[0]`` slice is the
            2-D mask (as built by ``collect_video_segments``).

    Returns:
        Array of shape (N, H, W) preserving the masks' dtype, or ``None`` when
        *masks* is empty.
    """
    if not masks:
        return None
    # np.stack adds a new object axis. np.concatenate(axis=0) would instead
    # glue N (H, W) masks into one (N*H, W) image that no longer matches the frame.
    return np.stack([mask[0] for mask in masks.values()], axis=0)


def render_frame(image, masks):
    """Draw object contours and filled masks over one frame.

    Args:
        image: frame as loaded by ``mmcv.imread``.
        masks: dict mapping object id -> boolean mask (see ``stack_object_masks``).

    Returns:
        The annotated image from ``Visualizer.get_image()``.
    """
    visualizer = Visualizer(image=image)
    stacked = stack_object_masks(masks)
    if stacked is None:
        # No tracked objects in this frame: emit it unannotated instead of
        # failing on an empty mask stack.
        return visualizer.get_image()
    polygons = []
    for mask in stacked:
        contours, _ = bitmap_to_polygon(mask)
        polygons.extend(contours)
    visualizer.draw_polygons(polygons, edge_colors='w', alpha=0.8)
    visualizer.draw_binary_masks(stacked, alphas=0.8)
    return visualizer.get_image()


if __name__ == '__main__':
    prepare()
    predictor = build_sam2_video_predictor(MODEL_CFG, MODEL_CKPT)
    inference_state = predictor.init_state(video_path=VID_PATH)

    # Single point prompt with label 1 on frame 0 for object id 1
    # (label 1 is presumably a foreground click under SAM's convention).
    input_point = np.array([[255, 475]])
    input_label = np.array([1])
    ann_frame_idx = 0
    ann_obj_id = 1

    # The prompted-frame outputs are unused; masks for all frames are
    # collected from propagate_in_video below.
    predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=input_point,
        labels=input_label,
    )

    video_segments = collect_video_segments(predictor, inference_state)
    frame_names = list_frame_names(VID_PATH)

    out_dir = './result'
    mmengine.mkdir_or_exist(out_dir)
    for idx, frame_name in enumerate(frame_names):
        image = mmcv.imread(os.path.join(VID_PATH, frame_name))
        result = render_frame(image, video_segments[idx])
        mmcv.imwrite(result, os.path.join(out_dir, frame_name))
|
|