| import gradio as gr |
| import torch |
| from PIL import Image |
| import numpy as np |
| from sam2 import build_sam2, SamPredictor |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Download the SAM2 checkpoint from the Hugging Face Hub; hf_hub_download
# returns the local filesystem path of the downloaded file.
# NOTE(review): confirm this filename matches what the repo actually publishes
# (some SAM2 releases ship ``.pt`` rather than ``.pth`` checkpoints).
model_path = hf_hub_download(repo_id="facebook/sam2-hiera-large", filename="sam2_hiera_large.pth")

# Use the GPU when one is available instead of always forcing the CPU;
# CPU-only machines behave exactly as before.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): assumes the installed ``sam2`` package exposes
# build_sam2(checkpoint=...) and SamPredictor with these signatures — confirm
# against the pinned version (the upstream facebookresearch/sam2 API differs).
model = build_sam2(checkpoint=model_path).to(device)

# One module-level predictor shared by every request. set_image() stores
# per-image state on it, so concurrent requests could interleave —
# NOTE(review): confirm the app's concurrency settings keep calls serialized.
predictor = SamPredictor(model)
|
|
def segment_image(input_image, x, y):
    """Segment the object under a single click and black out everything else.

    Args:
        input_image: PIL image from the Gradio upload widget, or None when
            nothing has been uploaded yet.
        x: Horizontal pixel coordinate of the prompt point (slider value).
        y: Vertical pixel coordinate of the prompt point (slider value).

    Returns:
        An RGB PIL image that keeps the input's pixels inside the predicted
        mask and is black elsewhere, or None if no image was provided.
    """
    if input_image is None:
        # Gradio passes None until the user uploads something; there is
        # nothing to segment, so leave the output empty instead of crashing.
        return None

    # Normalise to 3-channel RGB. RGBA (PNG with alpha), grayscale and palette
    # uploads would otherwise hand the predictor a non-HxWx3 array and make
    # Image.composite fail on a mode mismatch with the RGB black backdrop.
    rgb_image = input_image.convert("RGB")
    image_array = np.array(rgb_image)
    height, width = image_array.shape[:2]

    predictor.set_image(image_array)

    # The sliders span a fixed 0-1000 range regardless of image size; clamp
    # the click into the image so the prompt never lands outside it.
    px = min(max(float(x), 0.0), width - 1)
    py = min(max(float(y), 0.0), height - 1)
    input_point = np.array([[px, py]])
    # Label 1 marks a positive (foreground) click in the SAM point-prompt API.
    input_label = np.array([1])

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    # Binarise explicitly. Bool and {0, 1} float masks map to the same 0/255
    # values as before. Any other values (e.g. logits) can no longer wrap
    # around in the uint8 cast.
    mask = np.asarray(masks[0]) > 0
    mask_image = Image.fromarray(mask.astype(np.uint8) * 255)

    background = Image.new("RGB", rgb_image.size, "black")
    return Image.composite(rgb_image, background, mask_image)
|
|
| |
# Gradio UI. The uploaded image plus the X/Y slider values form the
# single-point prompt passed to segment_image.
# NOTE(review): the sliders are fixed to 0-1000, so points beyond pixel 1000
# in larger images cannot be selected; segment_image clamps to image bounds.
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(0, 1000, label="X coordinate"),
        gr.Slider(0, 1000, label="Y coordinate"),
    ],
    outputs=gr.Image(type="pil"),
    title="SAM2 Image Segmentation",
    description="Upload an image and select a point to segment. Adjust X and Y coordinates to refine the selection.",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, so importing this
    # module (e.g. from tests or another app) does not launch it as a side
    # effect.
    iface.launch()