| import spaces |
| import gradio as gr |
| import torch |
| import numpy as np |
| from PIL import Image |
| import requests |
| import warnings |
| import json |
| import os |
| from datetime import datetime |
| from threading import Thread |
| from queue import Queue |
| import time |
# Globally silence every Python warning raised after this point.
warnings.filterwarnings("ignore")
|
|
| |
# Model setup: choose the compute device once, then load the SAM3 weights and
# the matching processor. Half precision is used on CUDA, full precision on CPU.
from transformers import Sam3Model, Sam3Processor

if torch.cuda.is_available():
    device = "cuda"
    _weights_dtype = torch.float16
else:
    device = "cpu"
    _weights_dtype = torch.float32

model = Sam3Model.from_pretrained("DiffusionWave/sam3", torch_dtype=_weights_dtype)
model = model.to(device)
processor = Sam3Processor.from_pretrained("DiffusionWave/sam3")
|
|
| |
# Shared state for the background-job feature.
job_queue = Queue()      # jobs waiting to be picked up by the worker thread
results_store = dict()   # job id -> status/result record
job_counter = 0          # last job id that was handed out
|
|
| |
# On-disk layout: the history index and the cropped objects both live under
# HISTORY_DIR.
HISTORY_DIR = "segmentation_history"
CROPS_DIR = os.path.join(HISTORY_DIR, "crops")
HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")

# Make sure both folders exist as soon as the module is imported.
for _folder in (HISTORY_DIR, CROPS_DIR):
    os.makedirs(_folder, exist_ok=True)
|
|
def load_history():
    """Load the persisted segmentation history.

    Returns:
        list: The entries stored in ``HISTORY_FILE``. An empty list is
        returned when the file is missing, unreadable, not valid JSON, or
        does not hold a JSON array.
    """
    try:
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            history = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers
        # json.JSONDecodeError and decoding errors. History is best-effort,
        # so fall back to "no history" instead of failing the request.
        return []

    # Callers insert into and slice the result, so anything but a list
    # (e.g. a hand-edited file containing an object) is treated as empty.
    return history if isinstance(history, list) else []
|
|
def save_history(history):
    """Persist the segmentation history to ``HISTORY_FILE``.

    The data is written to a temporary file in the same directory and then
    moved over the target with ``os.replace``, so an error or crash while
    dumping never leaves a truncated history file behind.

    Args:
        history: JSON-serialisable list of history entries.

    Raises:
        OSError: If the file cannot be written or replaced.
        TypeError: If ``history`` contains values ``json`` cannot serialise.
    """
    import tempfile  # local import: only needed by this function

    directory = os.path.dirname(HISTORY_FILE) or "."
    tmp = tempfile.NamedTemporaryFile(
        "w", encoding="utf-8", dir=directory, suffix=".tmp", delete=False
    )
    try:
        with tmp:
            json.dump(history, tmp, indent=2)
        os.replace(tmp.name, HISTORY_FILE)
    except BaseException:
        # Drop the partial temp file; the previous history file is untouched.
        if os.path.exists(tmp.name):
            os.remove(tmp.name)
        raise
|
|
def crop_segmented_objects(image: Image.Image, masks, text: str, timestamp: str):
    """Cut every segmented object out of ``image`` and save it as a PNG.

    For each non-empty mask the padded bounding box is cropped from the
    image, the mask becomes the alpha channel, and the RGBA result is written
    to ``CROPS_DIR``.

    Args:
        image: Source picture; converted to RGB before cropping.
        masks: Iterable of 2-D masks (torch tensors or array-likes). They are
            assumed to share the image's height and width.
        text: Prompt used for the segmentation. A sanitised form of it is
            embedded in the file names.
        timestamp: Timestamp string embedded in the file names.

    Returns:
        list[str]: Paths of the files written, in mask order. Masks with no
        set pixel are skipped.
    """
    cropped_paths = []
    # Force three channels so the RGBA assembly below also works for
    # grayscale, palette and RGBA inputs.
    image_np = np.array(image.convert("RGB"))
    height, width = image_np.shape[:2]
    padding = 10

    # The prompt is user input: keep only characters that are safe inside a
    # file name so it cannot introduce path separators or "..".
    safe_text = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in text.strip()
    )[:50] or "object"
    safe_stamp = timestamp.replace(":", "-").replace(" ", "_")

    for i, mask in enumerate(masks):
        if isinstance(mask, torch.Tensor):
            mask_np = mask.cpu().numpy()
        else:
            mask_np = np.asarray(mask)

        # Indices of the rows/columns that contain at least one mask pixel.
        rows = np.flatnonzero(np.any(mask_np, axis=1))
        cols = np.flatnonzero(np.any(mask_np, axis=0))
        if rows.size == 0 or cols.size == 0:
            continue

        # Slice ends are exclusive, hence the +1: without it the last
        # row/column of the object gets one pixel less padding than the first.
        y_min = max(0, int(rows[0]) - padding)
        y_max = min(height, int(rows[-1]) + 1 + padding)
        x_min = max(0, int(cols[0]) - padding)
        x_max = min(width, int(cols[-1]) + 1 + padding)

        cropped = image_np[y_min:y_max, x_min:x_max]
        mask_crop = mask_np[y_min:y_max, x_min:x_max]

        # RGB from the image, alpha from the mask (transparent background).
        cropped_rgba = np.zeros((*cropped.shape[:2], 4), dtype=np.uint8)
        cropped_rgba[:, :, :3] = cropped
        cropped_rgba[:, :, 3] = (mask_crop * 255).astype(np.uint8)

        crop_filename = f"crop_{safe_stamp}_{safe_text}_{i+1}.png"
        crop_path = os.path.join(CROPS_DIR, crop_filename)
        Image.fromarray(cropped_rgba).save(crop_path)
        cropped_paths.append(crop_path)

    return cropped_paths
|
|
def add_to_history(image_path, prompt, n_masks, scores, timestamp, crop_paths):
    """Prepend a new entry to the persisted history and save it.

    Args:
        image_path: Path of the saved input image.
        prompt: Text prompt used for the segmentation.
        n_masks: Number of masks found.
        scores: List of confidence scores.
        timestamp: Timestamp string of the run.
        crop_paths: Paths of the cropped object images.

    Returns:
        list: The updated history (newest first, at most 100 entries).
    """
    history = load_history()

    # Derive the id from the largest existing id rather than from the list
    # length: the list is capped at 100 entries, so len(history) + 1 would
    # hand out 101 to every new entry once the cap is reached.
    known_ids = [
        item["id"]
        for item in history
        if isinstance(item, dict) and isinstance(item.get("id"), int)
    ]
    entry = {
        "id": max(known_ids, default=0) + 1,
        "timestamp": timestamp,
        "image_path": image_path,
        "prompt": prompt,
        "n_masks": n_masks,
        "scores": scores,
        "crop_paths": crop_paths
    }
    history.insert(0, entry)

    # Keep only the 100 most recent entries.
    history = history[:100]
    save_history(history)
    return history
|
|
@spaces.GPU()
def segment_core(image: Image.Image, text: str, threshold: float, mask_threshold: float, save_crops: bool = True):
    """Run text-prompted instance segmentation on one image.

    Args:
        image: Input picture, or ``None`` when nothing was uploaded.
        text: Text prompt describing the objects to find.
        threshold: Detection score threshold passed to post-processing.
        mask_threshold: Mask binarisation threshold passed to post-processing.
        save_crops: When true, each object is cropped and saved to disk.

    Returns:
        tuple: ``(annotated, info, metadata, cropped_images)`` where
        ``annotated`` is ``(image, [(mask, label), ...])`` (or ``None`` if no
        image was given), ``info`` is a Markdown status message, ``metadata``
        is a dict with ``n_masks``, ``scores``, ``crop_paths`` and ``masks``
        (``None`` on failure or when nothing was found), and
        ``cropped_images`` holds up to ten cropped previews.
    """
    if image is None:
        return None, "❌ Please upload an image.", None, []

    # Tolerate a missing prompt value as well as an empty/blank one.
    prompt = (text or "").strip()
    if not prompt:
        return (image, []), "❌ Please enter a text prompt.", None, []

    try:
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

        # Match float inputs to the model's precision (float16 on GPU).
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        masks = results['masks']
        n_masks = len(masks)
        if n_masks == 0:
            return (image, []), f"❌ No objects found matching '{text}' (try adjusting thresholds).", None, []

        # One (mask, label) pair per detected instance for the annotated view.
        annotations = []
        for i, (mask, score) in enumerate(zip(masks, results['scores'])):
            mask_np = mask.cpu().numpy().astype(np.float32)
            label = f"{text} #{i+1} ({score:.2f})"
            annotations.append((mask_np, label))

        scores_list = results['scores'].cpu().numpy().tolist()
        scores_text = ", ".join([f"{s:.2f}" for s in scores_list[:5]])
        info = f"✅ Found **{n_masks}** objects matching **'{text}'**\n"
        info += f"Confidence scores: {scores_text}{'...' if n_masks > 5 else ''}\n"

        cropped_images = []
        if save_crops:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            crop_paths = crop_segmented_objects(image, masks, text, timestamp)
            info += f"✂️ Extracted **{len(crop_paths)}** cropped objects"

            # Preview at most ten crops. Image.open is lazy and keeps the
            # file open, so copy the pixels and close the handle right away.
            for path in crop_paths[:10]:
                if os.path.exists(path):
                    with Image.open(path) as crop_file:
                        cropped_images.append(crop_file.copy())
        else:
            crop_paths = []

        metadata = {
            "n_masks": n_masks,
            "scores": scores_list,
            "crop_paths": crop_paths,
            # Keep masks on the CPU: metadata may be stored long-term by
            # callers and must not pin device memory.
            "masks": masks.cpu() if isinstance(masks, torch.Tensor) else masks
        }

        return (image, annotations), info, metadata, cropped_images

    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        return (image, []), f"❌ Error during segmentation: {str(e)}", None, []
|
|
def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """Segment an image for the UI and record the run in the history.

    Args:
        image: Input picture.
        text: Text prompt describing the objects to find.
        threshold: Detection score threshold.
        mask_threshold: Mask binarisation threshold.

    Returns:
        tuple: ``(annotated, info, cropped_images)`` for the annotated image,
        the info Markdown and the crop gallery.
    """
    result, info, metadata, cropped_images = segment_core(image, text, threshold, mask_threshold, save_crops=True)

    if metadata and metadata["n_masks"] > 0:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Nanosecond resolution keeps two runs within the same second from
        # overwriting each other's saved input image.
        img_filename = f"img_{time.time_ns()}.jpg"
        img_path = os.path.join(HISTORY_DIR, img_filename)
        try:
            # JPEG cannot store alpha or palette images, so normalise first.
            image.convert("RGB").save(img_path)
            add_to_history(
                img_path,
                text,
                metadata["n_masks"],
                metadata["scores"],
                timestamp,
                metadata["crop_paths"]
            )
        except (OSError, ValueError, TypeError) as exc:
            # History is a convenience: a failed save must not throw away a
            # segmentation that already succeeded.
            info += f"\n⚠️ Could not save this run to history: {exc}"

    return result, info, cropped_images
|
|
def background_worker():
    """Process queued segmentation jobs until a ``None`` sentinel arrives.

    Each job is a ``(job_id, image, text, threshold, mask_threshold)`` tuple
    taken from ``job_queue``. The outcome is written to
    ``results_store[job_id]`` with status ``"completed"`` or ``"failed"``;
    successful runs that found objects are also added to the history.
    """
    while True:
        job = job_queue.get()
        try:
            if job is None:
                break

            job_id, image, text, threshold, mask_threshold = job

            try:
                result, info, metadata, cropped_images = segment_core(image, text, threshold, mask_threshold, save_crops=True)
            except Exception as e:
                # Thread boundary: record the failure and keep serving jobs.
                results_store[job_id] = {
                    "status": "failed",
                    "error": str(e),
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                }
                continue

            record = {
                "status": "completed",
                "result": result,
                "info": info,
                "metadata": metadata,
                "cropped_images": cropped_images,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
            results_store[job_id] = record

            if metadata and metadata["n_masks"] > 0:
                # Job ids restart at 1 on every launch, so add a timestamp to
                # avoid overwriting images saved by earlier sessions.
                img_filename = f"bg_img_{job_id}_{time.time_ns()}.jpg"
                img_path = os.path.join(HISTORY_DIR, img_filename)
                try:
                    # JPEG cannot store alpha or palette images.
                    image.convert("RGB").save(img_path)
                    add_to_history(
                        img_path,
                        text,
                        metadata["n_masks"],
                        metadata["scores"],
                        record["timestamp"],
                        metadata["crop_paths"]
                    )
                except Exception as e:
                    # A history failure must not turn a finished job into a
                    # failed one; just mention it in the job's info text.
                    record["info"] = f"{info}\n⚠️ Could not save this run to history: {e}"
        finally:
            # Always balance get(), including for the sentinel, so
            # job_queue.join() cannot hang.
            job_queue.task_done()
|
|
| |
# Start the single background worker at import time. It is a daemon thread,
# so it does not keep the process alive on shutdown.
worker_thread = Thread(daemon=True, target=background_worker)
worker_thread.start()
|
|
def submit_background_job(image, text, threshold, mask_threshold):
    """Queue a segmentation job for the background worker.

    Args:
        image: Input picture, or ``None`` when nothing was uploaded.
        text: Text prompt describing the objects to find.
        threshold: Detection score threshold.
        mask_threshold: Mask binarisation threshold.

    Returns:
        tuple[str, str]: A status message and the new job id as a string
        (empty when the input was rejected).
    """
    global job_counter
    if image is None or not (text or "").strip():
        return "❌ Please provide image and text prompt.", ""

    job_counter += 1
    job_id = job_counter

    # Register the job BEFORE enqueueing it. In the opposite order a fast
    # worker could store its result first and then have it overwritten by
    # this "processing" record, leaving the job stuck forever.
    results_store[job_id] = {"status": "processing"}
    job_queue.put((job_id, image, text, threshold, mask_threshold))

    return f"✅ Job #{job_id} submitted to background queue.", f"{job_id}"
|
|
def check_background_job(job_id_str):
    """Report the state of a background job.

    Args:
        job_id_str: Job id as entered by the user (may be empty or ``None``).

    Returns:
        tuple: ``(status_message, annotated_result, cropped_images)``. The
        last two are ``None`` and ``[]`` unless the job has completed.
    """
    job_id_text = (job_id_str or "").strip()
    if not job_id_text:
        return "❌ Please enter a job ID.", None, []

    # Only the conversion can raise ValueError, so keep the try minimal.
    try:
        job_id = int(job_id_text)
    except ValueError:
        return "❌ Invalid job ID format.", None, []

    job_data = results_store.get(job_id)
    if job_data is None:
        return f"❌ Job #{job_id} not found.", None, []

    status = job_data["status"]
    if status == "processing":
        return f"⏳ Job #{job_id} is still processing...", None, []
    if status == "completed":
        return (
            f"✅ Job #{job_id} completed!\n{job_data['info']}",
            job_data["result"],
            job_data.get("cropped_images", [])
        )
    return f"❌ Job #{job_id} failed: {job_data.get('error', 'Unknown error')}", None, []
|
|
def load_history_display():
    """Render the 20 most recent history entries as Markdown.

    Returns:
        str: A Markdown document listing id, timestamp, prompt, object count,
        crop count and top three scores per entry, or a placeholder message
        when the history is empty.
    """
    history = load_history()
    if not history:
        return "📝 No history yet. Start segmenting images!"

    # Collect the pieces and join once instead of growing a string with +=.
    sections = ["## Segmentation History\n\n"]
    for entry in history[:20]:
        top_scores = ", ".join(f"{s:.2f}" for s in entry["scores"][:3])
        sections.append(
            f"**#{entry['id']}** - {entry['timestamp']}\n"
            f"- Prompt: `{entry['prompt']}`\n"
            f"- Found: {entry['n_masks']} objects\n"
            f"- Cropped: {len(entry.get('crop_paths', []))} images\n"
            f"- Top scores: {top_scores}\n\n"
        )

    return "".join(sections)
|
|
def load_history_item(item_id):
    """Load one history entry together with its image and cropped objects.

    Args:
        item_id: Entry id as entered by the user (string or number).

    Returns:
        tuple: ``(image, prompt, info, cropped_images)``. ``image`` is
        ``None`` when the saved file is missing or unreadable; at most ten
        crops are returned. For an unknown or invalid id the result is
        ``(None, "", message, [])``.
    """
    # Parse once, up front: a blank or non-numeric id is a user error and
    # should produce a message rather than an exception.
    try:
        wanted_id = int(item_id)
    except (TypeError, ValueError):
        return None, "", "❌ Please enter a numeric history item ID.", []

    def _read_image(path):
        """Return a fully loaded copy of the image at ``path``, or None."""
        if not os.path.exists(path):
            return None
        try:
            # Image.open is lazy and keeps the file open; copy the pixels and
            # close the handle immediately.
            with Image.open(path) as img_file:
                return img_file.copy()
        except OSError:
            return None

    history = load_history()
    for entry in history:
        if entry['id'] == wanted_id:
            info = f"**History item #{entry['id']}**\n"
            info += f"Timestamp: {entry['timestamp']}\n"
            info += f"Prompt: `{entry['prompt']}`\n"
            info += f"Objects found: {entry['n_masks']}\n"
            info += f"Cropped images: {len(entry.get('crop_paths', []))}"

            image = _read_image(entry['image_path'])

            cropped_images = []
            for crop_path in entry.get('crop_paths', [])[:10]:
                crop = _read_image(crop_path)
                if crop is not None:
                    cropped_images.append(crop)

            return image, entry['prompt'], info, cropped_images

    return None, "", f"❌ History item #{item_id} not found", []
|
|
def clear_all():
    """Return reset values for the segmentation tab.

    Returns:
        tuple: Values for, in order, the input image, the prompt, the
        annotated output, both threshold sliders (0.5 each), the info
        message and the crop gallery.
    """
    return None, "", None, 0.5, 0.5, "📝 Enter a prompt and click **Segment** to start.", []
|
|
def segment_example(image_path: str, prompt: str):
    """Run segmentation for an example with the default thresholds.

    Args:
        image_path: HTTP(S) URL or local path of the example image. An
            already loaded PIL image is accepted too, because the image
            input component is configured with ``type="pil"``.
        prompt: Text prompt describing the objects to find.

    Returns:
        tuple: Whatever ``segment`` returns for thresholds 0.5 / 0.5.

    Raises:
        requests.RequestException: If a remote image cannot be downloaded.
    """
    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGB")
    elif image_path.startswith(("http://", "https://")):
        import io  # local import: only needed for remote examples

        # Bound the wait and fail on HTTP errors. Use the decoded body rather
        # than the raw socket stream, which skips content decoding.
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    else:
        # convert() loads the pixels, so the file can be closed right after.
        with Image.open(image_path) as img_file:
            image = img_file.convert("RGB")
    return segment(image, prompt, 0.5, 0.5)
|
|
| |
# Gradio UI: three tabs (interactive segmentation, background jobs, history)
# followed by the event wiring and a notes section.
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css=".gradio-container {max-width: 1600px !important;}"
) as demo:
    gr.Markdown(
        """
        # SAM3 - Promptable Concept Segmentation (PCS)

        **SAM3** performs zero-shot instance segmentation using natural language prompts.
        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks + cropped objects.

        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )

    with gr.Tabs():
        # ---- Tab 1: interactive segmentation ----
        with gr.Tab("🎯 Segmentation"):
            gr.Markdown("### Inputs")
            with gr.Row(variant="panel"):
                image_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    height=400,
                )
                image_output = gr.AnnotatedImage(
                    label="Output (Segmented Image)",
                    height=400,
                    show_legend=True,
                )

            with gr.Row():
                text_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., person, ear, cat, bicycle...",
                    scale=3
                )
                clear_btn = gr.Button("🔄 Clear", size="sm", variant="secondary")

            with gr.Row():
                thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Detection Threshold",
                    info="Higher = fewer detections"
                )
                mask_thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Mask Threshold",
                    info="Higher = sharper masks"
                )

            info_output = gr.Markdown(
                value="📝 Enter a prompt and click **Segment** to start.",
                label="Info / Results"
            )

            segment_btn = gr.Button("🎯 Segment Now", variant="primary", size="lg")

            gr.Markdown("### ✂️ Cropped Objects")
            cropped_gallery = gr.Gallery(
                label="Extracted Objects",
                columns=5,
                height=300,
                object_fit="contain"
            )

            gr.Examples(
                examples=[
                    ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
                ],
                inputs=[image_input, text_input],
                outputs=[image_output, info_output, cropped_gallery],
                fn=segment_example,
                cache_examples=False,
            )

        # ---- Tab 2: background job queue ----
        with gr.Tab("⚙️ Background Processing"):
            gr.Markdown(
                """
                ### Background Job Queue
                Submit segmentation jobs that run independently in the background.
                Useful for batch processing or when you want to continue working while processing.
                """
            )

            with gr.Row():
                bg_image_input = gr.Image(label="Image", type="pil", height=300)
                bg_status_output = gr.Markdown("📝 Submit a job to start background processing.")

            with gr.Row():
                bg_text_input = gr.Textbox(label="Text Prompt", placeholder="e.g., person, car...")
                bg_job_id_output = gr.Textbox(label="Job ID", interactive=False)

            with gr.Row():
                bg_thresh = gr.Slider(0.0, 1.0, 0.5, 0.01, label="Detection Threshold")
                bg_mask_thresh = gr.Slider(0.0, 1.0, 0.5, 0.01, label="Mask Threshold")

            bg_submit_btn = gr.Button("📤 Submit Background Job", variant="primary")

            gr.Markdown("---")
            gr.Markdown("### Check Job Status")

            with gr.Row():
                check_job_id = gr.Textbox(label="Enter Job ID", placeholder="e.g., 1")
                check_btn = gr.Button("🔍 Check Status", variant="secondary")

            check_status_output = gr.Markdown("Enter a job ID and click Check Status.")
            check_result_output = gr.AnnotatedImage(label="Result", height=400)

            gr.Markdown("### Cropped Objects from Job")
            check_cropped_gallery = gr.Gallery(
                label="Extracted Objects",
                columns=5,
                height=300,
                object_fit="contain"
            )

        # ---- Tab 3: persisted history ----
        with gr.Tab("📚 History"):
            gr.Markdown("### Segmentation History")

            with gr.Row():
                refresh_history_btn = gr.Button("🔄 Refresh History", variant="secondary")
                history_item_id = gr.Textbox(label="Load History Item #", placeholder="Enter ID")
                load_history_btn = gr.Button("📂 Load Item", variant="primary")

            # Pass the function itself rather than its result so the list is
            # rebuilt on every page load instead of once at start-up.
            history_display = gr.Markdown(value=load_history_display)

            gr.Markdown("---")
            gr.Markdown("### Loaded History Item")

            with gr.Row():
                history_image = gr.Image(label="Original Image", type="pil", height=300)
                history_info = gr.Markdown("Select a history item to view.")

            history_prompt = gr.Textbox(label="Prompt", interactive=False)

            gr.Markdown("### Cropped Objects from History")
            history_cropped_gallery = gr.Gallery(
                label="Extracted Objects",
                columns=5,
                height=300,
                object_fit="contain"
            )

    # ---- Event wiring ----
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output, cropped_gallery]
    )

    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output, cropped_gallery]
    )

    bg_submit_btn.click(
        fn=submit_background_job,
        inputs=[bg_image_input, bg_text_input, bg_thresh, bg_mask_thresh],
        outputs=[bg_status_output, bg_job_id_output]
    )

    check_btn.click(
        fn=check_background_job,
        inputs=[check_job_id],
        outputs=[check_status_output, check_result_output, check_cropped_gallery]
    )

    refresh_history_btn.click(
        fn=load_history_display,
        outputs=[history_display]
    )

    load_history_btn.click(
        fn=load_history_item,
        inputs=[history_item_id],
        outputs=[history_image, history_prompt, history_info, history_cropped_gallery]
    )

    gr.Markdown(
        """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
        - Background jobs run independently and are tracked by Job ID
        - All segmented objects are automatically cropped and saved
        - Cropped images have transparent backgrounds (PNG format)
        - History is saved automatically and persists across sessions
        - GPU recommended for faster inference
        """
    )
|
|
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, without a public share link and
    # with debug mode enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
    }
    demo.launch(**launch_options)