import functools
import gc
import os
import traceback
from itertools import chain

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from segment_anything import SamPredictor, sam_model_registry
|
|
|
|
| |
|
|
# File name of the Sam checkpoint handed to the model constructor in
# set_predictor. NOTE(review): nothing visible in this file downloads it, so it
# presumably has to exist in the working directory already -- confirm.
sam_checkpoint = "sam_vit_h_0.pth"

# Key used to pick the model constructor out of sam_model_registry.
model_type = "vit_h"

# Use the GPU when torch reports CUDA as available, otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_sam(model_name, checkpoint, target_device):
    """Build the Sam model from its checkpoint once and reuse it afterwards."""
    sam = sam_model_registry[model_name](checkpoint=checkpoint)
    sam.to(device=target_device)
    return sam


def set_predictor(image):
    """
    Creates a Sam predictor object based on a given image and model.

    Args:
        image: The image taken from the 'Image' component, or None when the
            button is pressed before anything was uploaded.

    Returns:
        A three-item list ``[image, predictor, status]`` matching the
        ``[image, embedding, output_status]`` outputs wired to the
        'Send Image' button. ``status`` is 'Done' on success. When no image
        was supplied the first two items are None and ``status`` asks the
        user to upload one.
    """
    # Guard: without an image there is nothing to embed, and set_image would
    # fail on None.
    if image is None:
        return [None, None, 'Please upload an image first']

    # The model itself is loaded only once (see _load_sam); each call only
    # creates a fresh predictor for the new image.
    predictor = SamPredictor(_load_sam(model_type, sam_checkpoint, device))
    predictor.set_image(image)

    # Release memory that is no longer referenced after computing the
    # embedding, so repeated uploads do not pile up on the GPU.
    if device == 'cuda':
        gc.collect()
        torch.cuda.empty_cache()

    return [image, predictor, 'Done']
|
|
|
|
def get_polygon(points, image, predictor):
    """
    Returns the points of the polygon given a bounding box and a prediction
    made by Sam.

    Args:
        points: The bounding box as two ``[x, y]`` corners, i.e.
            ``[[x0, y0], [x1, y1]]``. The corners may be given in any order.
        image: The image array stored by set_predictor; only its shape is
            used, to size the returned mask.
        predictor: The predictor stored by set_predictor.

    Returns:
        A ``(polygon, mask)`` tuple:
        * ``polygon`` is a list of ``[x, y]`` integer pairs outlining the
          largest segment found, or ``[]`` when nothing was found.
        * ``mask`` is an array shaped like ``image`` with the polygon filled
          using ``(0, 255, 0)``. When no contour is found it is the raw
          (all-zero) predicted mask, and it is None when ``image`` or
          ``predictor`` is missing.
    """
    # 'Send bounding box' can be pressed before 'Send Image' has stored an
    # image and a predictor; report "nothing found" instead of crashing.
    if image is None or predictor is None:
        return [], None

    # The UI asks for the upper-left corner first and the lower-right second,
    # but sorting the coordinates makes the click order irrelevant.
    (x_a, y_a), (x_b, y_b) = points
    input_box = np.array([
        min(x_a, x_b),
        min(y_a, y_b),
        max(x_a, x_b),
        max(y_a, y_b),
    ])

    masks, _, _ = predictor.predict(
        box=input_box[None, :],
        multimask_output=False,
    )

    binary_mask = masks[0].astype(np.uint8)
    # Only outer outlines are needed, so holes inside a segment are ignored.
    contours, _ = cv2.findContours(
        binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    if not contours:
        return [], binary_mask

    # A mask may contain small stray blobs; keep the biggest outline rather
    # than whichever contour happens to come first.
    contour = max(contours, key=cv2.contourArea)
    polygon = contour.reshape(-1, 2).tolist()

    mask = np.zeros(image.shape, dtype='uint8')
    # Pass the contour as returned by findContours: it is already in the
    # 32-bit integer format fillPoly expects, whereas an array rebuilt from
    # Python ints takes NumPy's default integer type, which is often 64-bit.
    cv2.fillPoly(mask, [contour], (0, 255, 0))

    return polygon, mask
|
|
|
|
def add_bbox(bbox, evt: gr.SelectData):
    """
    Stores the clicked image position as one corner of the bounding box.

    The first corner is filled while it still holds the ``[0, 0]``
    placeholder; every later click (re)writes the second corner.

    Args:
        bbox: Current bounding box, ``[[x0, y0], [x1, y1]]``.
        evt: Select event; ``evt.index[0]`` and ``evt.index[1]`` are stored as
            the corner coordinates.

    Returns:
        The updated bounding box twice, once for the ``bbox`` state and once
        for the 'bbox' textbox.
    """
    clicked = [evt.index[0], evt.index[1]]
    first_corner, second_corner = bbox

    # Build a new list instead of editing the incoming state in place.
    if first_corner == [0, 0]:
        updated_bbox = [clicked, second_corner]
    else:
        updated_bbox = [first_corner, clicked]

    return updated_bbox, updated_bbox
|
|
|
|
def clear_bbox(bbox):
    """
    Resets the bounding box so that two new corners can be selected.

    Args:
        bbox: Current bounding box; it is ignored and only present because the
            'Clear bbox' button passes the state in.

    Returns:
        The placeholder box ``[[0, 0], [0, 0]]`` twice, once for the ``bbox``
        state and once for the 'bbox' textbox.
    """
    empty_bbox = [[0, 0], [0, 0]]
    return empty_bbox, empty_bbox
|
|
|
|
with gr.Blocks() as demo:
    # The text is kept at column 0 so the Markdown never sees leading
    # indentation. '\\*' yields a literal backslash-asterisk in the string,
    # which Markdown shows as a plain '*' footnote marker.
    gr.Markdown(
        """
# Instructions
1. Upload the image and press 'Send Image'.
2. Wait until the word 'Done' appears on the 'Status' box.
3. Click on the image where the upper left corner of the bbox should be.
4. Click on the image where the lower right corner of the bbox should be.
5. Check the coordinates using the 'bbox' box.
6. Click on 'Send bounding box'.
7. On the right side you will see the binary mask '\\*'.
8. On the lower side you will see the points that made up the polygon '\\*'.
9. Click on 'Clear bbox' to send another bounding box and repeat the steps from the third step.
10. Repeat steps 3 to 9 until all the segments for this image are done.
11. Click on the right corner of the image to remove it and repeat all the steps with the next
image.

'\\*' If the binary mask is all black and the polygon is an empty list, it means the program did
not find any segment in the bbox. Make the bbox a little bit bigger if that happens.
""")

    # State shared between the callbacks: the submitted image, the predictor
    # built for it, and the two bbox corners ([0, 0] means "not set yet").
    image = gr.State()
    embedding = gr.State()
    bbox = gr.State([[0, 0], [0, 0]])

    # Top row: the image to click on and the resulting mask.
    with gr.Row():
        input_image = gr.Image(label='Image')
        mask = gr.Image(label='Mask')

    # Status of the predictor and the button that builds it.
    with gr.Row():
        with gr.Column():
            output_status = gr.Textbox(label='Status')

        with gr.Column():
            predictor_button = gr.Button('Send Image')

    # Current bbox coordinates and the button that resets them.
    with gr.Row():
        with gr.Column():
            bbox_box = gr.Textbox(label="bbox")

        with gr.Column():
            bbox_button = gr.Button('Clear bbox')

    # Polygon points and the button that runs the segmentation.
    with gr.Row():
        with gr.Column():
            polygon = gr.Textbox(label='Polygon')

        with gr.Column():
            points_button = gr.Button('Send bounding box')

    # 'Send Image': build the predictor and store image + predictor in state.
    predictor_button.click(
        set_predictor,
        input_image,
        [image, embedding, output_status],
    )

    # 'Send bounding box': segment inside the bbox and show polygon + mask.
    points_button.click(
        get_polygon,
        [bbox, image, embedding],
        [polygon, mask],
    )

    # 'Clear bbox': reset both corners.
    bbox_button.click(
        clear_bbox,
        bbox,
        [bbox, bbox_box],
    )

    # Clicking on the image records a bbox corner.
    input_image.select(
        add_bbox,
        bbox,
        [bbox, bbox_box]
    )


demo.launch(debug=True)