import logging
import os
import urllib.request

import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
|
|
# SAM ViT-H weights, fetched once into the current working directory
# (the path is relative) and reused on later starts.
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

if not os.path.exists(SAM_CHECKPOINT):
    print("Downloading SAM checkpoint...")
    # Download into a side file and rename only on success: an interrupted
    # download must not leave a truncated checkpoint at SAM_CHECKPOINT, because
    # the exists() check above would then skip the download on every restart.
    _sam_partial_path = SAM_CHECKPOINT + ".part"
    try:
        urllib.request.urlretrieve(SAM_CHECKPOINT_URL, _sam_partial_path)
        os.replace(_sam_partial_path, SAM_CHECKPOINT)
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(_sam_partial_path):
            os.remove(_sam_partial_path)
    print("SAM checkpoint downloaded!")
|
|
# Target device for both models: the GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
# Grounding DINO (tiny): zero-shot, text-prompted object detector.
# The processor turns image + prompt into model inputs; the model is moved
# to `device` once, at start-up.
_GROUNDING_DINO_REPO = "IDEA-Research/grounding-dino-tiny"
grounding_dino_processor = AutoProcessor.from_pretrained(_GROUNDING_DINO_REPO)
grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(_GROUNDING_DINO_REPO)
grounding_dino_model = grounding_dino_model.to(device)
|
|
# Segment Anything with the ViT-H backbone, built from the local checkpoint,
# moved to `device`, and wrapped in a SamPredictor that turns box prompts
# into masks.
_build_sam_vit_h = sam_model_registry["vit_h"]
sam = _build_sam_vit_h(checkpoint=SAM_CHECKPOINT).to(device=device)
sam_predictor = SamPredictor(sam)
|
|
# Longest image side (pixels) for each "Processing Quality" choice; any other
# value falls back to the "High" size.
_QUALITY_MAX_SIDE = {"Low": 800, "Medium": 1024, "High": 1920}


def _prepare_prompt(text_prompt):
    """Return *text_prompt* stripped, lower-cased and ending with a period.

    Grounding DINO expects lower-cased queries terminated by a dot. Returns
    "" for a missing or blank prompt.
    """
    prompt = (text_prompt or "").strip().lower()
    if prompt and not prompt.endswith("."):
        prompt += "."
    return prompt


def _downscale(image, max_side):
    """Shrink *image* so its longest side is at most *max_side* pixels."""
    h, w = image.shape[:2]
    if max(h, w) <= max_side:
        return image
    scale = max_side / max(h, w)
    # cv2.resize takes (width, height); never let a side collapse to 0.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    # INTER_AREA avoids the aliasing the default bilinear filter gives when shrinking.
    return cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)


def _detect(image, prompt, box_threshold, text_threshold):
    """Run Grounding DINO on an RGB array; return (xyxy boxes ndarray, label strings)."""
    pil_image = Image.fromarray(image)  # already RGB: no channel conversion
    inputs = grounding_dino_processor(
        images=pil_image, text=prompt, return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = grounding_dino_model(**inputs)

    results = grounding_dino_processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[pil_image.size[::-1]],  # PIL size is (w, h); wants (h, w)
    )[0]

    # Recent transformers versions return the phrase strings under
    # "text_labels"; older ones return them under "labels".
    labels = results.get("text_labels") or results.get("labels", [])
    return results["boxes"].cpu().numpy(), [str(label) for label in labels]


def _segment(image, boxes):
    """Return one boolean SAM mask per xyxy box for an RGB *image*."""
    sam_predictor.set_image(image)  # SamPredictor's default image_format is "RGB"
    masks = []
    for box in boxes:
        mask, _, _ = sam_predictor.predict(box=box, multimask_output=False)
        masks.append(mask[0])
    return masks


def _draw(image, boxes, labels, masks):
    """Return a copy of *image* with half-transparent masks, then boxes and labels."""
    canvas = image.copy()
    # One random colour per detection, shared by its mask, box and label.
    colors = [np.random.randint(0, 256, 3) for _ in boxes]

    # All masks first, so no mask is blended over an already-drawn box or label.
    for mask, color in zip(masks, colors):
        canvas[mask] = (canvas[mask] * 0.5 + color * 0.5).astype(np.uint8)

    for box, label, color in zip(boxes, labels, colors):
        x1, y1, x2, y2 = map(int, box)
        rgb = tuple(int(c) for c in color)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), rgb, 2)
        # Keep the label on-canvas for boxes that touch the top edge.
        cv2.putText(canvas, label, (x1, max(y1 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, rgb, 2)
    return canvas


def process_image(image, text_prompt, box_threshold, text_threshold, quality):
    """Detect objects named by *text_prompt* with Grounding DINO and segment them with SAM.

    Args:
        image: RGB uint8 array as delivered by ``gr.Image(type="numpy")``,
            or None when nothing has been uploaded.
        text_prompt: Free-text description of what to detect.
        box_threshold: Minimum box confidence kept by Grounding DINO.
        text_threshold: Minimum text-matching score for a phrase label.
        quality: "Low", "Medium" or "High"; caps the longest image side at
            800 / 1024 / 1920 px before inference (unknown values use 1920).

    Returns:
        ``(image, message)``. On success, the image is the (possibly
        downscaled) input with masks, boxes and labels drawn on it. Errors
        are reported through the message instead of raised, so the Gradio
        UI always has something to display.
    """
    if image is None:
        return None, "❌ Please upload an image first."

    prompt = _prepare_prompt(text_prompt)
    if not prompt:
        return image, "❌ Please enter a text prompt describing what to detect."

    try:
        image = _downscale(image, _QUALITY_MAX_SIDE.get(quality, 1920))

        boxes, labels = _detect(image, prompt, box_threshold, text_threshold)
        if len(boxes) == 0:
            return image, "No objects detected. Try adjusting the thresholds or text prompt."

        masks = _segment(image, boxes)
        result_image = _draw(image, boxes, labels, masks)
        return result_image, f"✅ Detected {len(boxes)} objects: {', '.join(labels)}"

    except Exception as e:
        # Gradio handler boundary: keep the traceback in the logs,
        # show only a short message in the UI.
        logging.getLogger(__name__).exception("Grounded SAM processing failed")
        return image, f"❌ Error: {e}"
|
|
# Gradio UI: image, prompt and advanced settings on the left; annotated result
# and detection summary on the right. The button runs process_image.
with gr.Blocks(title="Grounded SAM") as demo:
    gr.Markdown("# 🎯 Grounded SAM - Object Detection & Segmentation")
    gr.Markdown("Upload an image and describe what you want to detect (e.g., 'fish', 'all fish', 'person').")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'fish', 'person', 'car'",
                value="fish",
            )

            with gr.Accordion("Advanced Settings", open=False):
                box_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.35, step=0.05,
                    label="Box Threshold (detection confidence)",
                )
                text_threshold = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                    label="Text Threshold (text matching confidence)",
                )
                quality = gr.Radio(
                    choices=["Low", "Medium", "High"],
                    value="Medium",
                    label="Processing Quality",
                )

            submit_btn = gr.Button("🚀 Process Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Output with Masks & Boxes", type="numpy")
            output_metadata = gr.Textbox(label="Detection Metadata", lines=3)

    # Same order as process_image's parameters; shared by the button and examples.
    ui_inputs = [input_image, text_prompt, box_threshold, text_threshold, quality]

    submit_btn.click(
        fn=process_image,
        inputs=ui_inputs,
        outputs=[output_image, output_metadata],
    )

    # Each row: (image path, prompt, box threshold, text threshold, quality).
    # Only list examples whose image file exists, so Gradio is never handed
    # paths it cannot load when the examples/ directory is missing.
    example_rows = [
        row
        for row in (
            ["examples/fish1.jpg", "fish", 0.35, 0.25, "Medium"],
            ["examples/fish2.jpg", "all fish", 0.35, 0.25, "Medium"],
        )
        if os.path.exists(row[0])
    ]
    if example_rows:
        gr.Examples(examples=example_rows, inputs=ui_inputs)

demo.launch()
|
|