| """ |
| Surgical-DeSAM Gradio App for Hugging Face Spaces |
| Supports both Image and Video segmentation with ZeroGPU |
| """ |
| import os |
| import spaces |
| import gradio as gr |
| import torch |
| import numpy as np |
| import cv2 |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
| import tempfile |
|
|
| |
| from models.detr_seg import DETR, SAMModel |
| from models.backbone import build_backbone |
| from models.transformer import build_transformer |
| from util.misc import NestedTensor |
|
|
| |
# Hub repository that hosts the checkpoints; can be overridden via the environment.
MODEL_REPO = os.environ.get("MODEL_REPO", "IFMedTech/surgical-desam-weights")
# Access token passed to every hub download; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Display names used when labelling detections, indexed by predicted class id.
INSTRUMENT_CLASSES = (
    'bipolar_forceps',
    'prograsp_forceps',
    'large_needle_driver',
    'monopolar_curved_scissors',
    'ultrasound_probe',
    'suction',
    'clip_applier',
    'stapler',
)

# One colour triple per class, applied to RGB frames for masks, boxes and text.
COLORS = [
    [0, 114, 189],
    [217, 83, 25],
    [237, 177, 32],
    [126, 47, 142],
    [119, 172, 48],
    [77, 190, 238],
    [162, 20, 47],
    [76, 76, 76],
]

# Lazily initialised inference state; load_models() fills these on first use.
model = None
seg_model = None
device = None
|
|
|
|
def download_weights():
    """Fetch every checkpoint the app needs from the Hugging Face Hub.

    The DeSAM and SAM checkpoints are placed under ``weights/`` and the Swin
    backbone checkpoint under ``swin_backbone/``; each directory is created
    when it does not exist yet.

    Returns:
        A ``(desam_path, sam_path)`` tuple holding the local paths reported by
        ``hf_hub_download``. The Swin checkpoint is downloaded but its path is
        not returned.
    """
    def fetch(filename, target_dir):
        # Make sure the destination exists before the hub client writes to it.
        os.makedirs(target_dir, exist_ok=True)
        return hf_hub_download(
            repo_id=MODEL_REPO,
            filename=filename,
            token=HF_TOKEN,
            local_dir=target_dir,
        )

    desam_path = fetch("surgical_desam_1024.pth", "weights")
    sam_path = fetch("sam_vit_b_01ec64.pth", "weights")
    # Downloaded for its side effect only; Args.backbone_dir points at this folder.
    fetch("swin_base_patch4_window7_224_22kto1k.pth", "swin_backbone")

    return desam_path, sam_path
|
|
|
|
class Args:
    """Static stand-in for the argument namespace used when building the model.

    ``load_models`` instantiates this class and hands the instance to
    ``build_backbone`` and ``build_transformer``; it also reads ``num_queries``
    and ``aux_loss`` when constructing ``DETR``.
    """

    # --- backbone -----------------------------------------------------------
    backbone = 'swin_B_224_22k'
    backbone_dir = './swin_backbone'  # same folder download_weights() fills
    lr_backbone = 1e-5
    dilation = False
    position_embedding = 'sine'
    masks = False

    # --- transformer --------------------------------------------------------
    hidden_dim = 256
    nheads = 8
    dim_feedforward = 2048
    enc_layers = 6
    dec_layers = 6
    dropout = 0.1
    pre_norm = False

    # --- detection head -----------------------------------------------------
    num_queries = 100
    aux_loss = False

    # --- runtime ------------------------------------------------------------
    dataset_file = 'endovis18'
    device = 'cuda'  # default only; load_models() overwrites it per instance
|
|
|
|
def load_models():
    """Build the DETR detector and the SAM segmenter and publish them as globals.

    Downloads the checkpoints via ``download_weights()``, constructs both
    networks, restores their weights, moves them to CUDA when available
    (otherwise CPU) and switches them to eval mode.

    The module-level ``model``, ``seg_model`` and ``device`` are assigned only
    after every step has succeeded. Callers use ``model is None`` as the
    "not loaded yet" signal, so a failure halfway through must not leave a
    half-initialised state behind; the next request simply retries.

    Raises:
        Whatever the download, model construction or checkpoint loading raise;
        nothing is caught here.
    """
    global model, seg_model, device

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    desam_path, sam_path = download_weights()

    args = Args()
    args.device = str(target_device)

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    detector = DETR(
        backbone,
        transformer,
        num_classes=9,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )

    # NOTE(security): weights_only=False performs a full unpickle, which can
    # execute arbitrary code. Acceptable only while MODEL_REPO is trusted.
    checkpoint = torch.load(desam_path, map_location='cpu', weights_only=False)

    # strict=False tolerates key mismatches; report them instead of hiding them
    # so a wrong checkpoint does not silently yield an untrained detector.
    load_result = detector.load_state_dict(checkpoint['model'], strict=False)
    missing = list(getattr(load_result, 'missing_keys', []) or [])
    unexpected = list(getattr(load_result, 'unexpected_keys', []) or [])
    if missing or unexpected:
        print(
            f"Warning: DETR checkpoint mismatch "
            f"({len(missing)} missing, {len(unexpected)} unexpected keys)"
        )
    detector.to(target_device)
    detector.eval()

    segmenter = SAMModel(device=target_device, ckpt_path=sam_path)
    # The DeSAM checkpoint may carry its own segmenter weights; when present
    # they replace the ones loaded from the SAM checkpoint.
    if 'seg_model' in checkpoint:
        segmenter.load_state_dict(checkpoint['seg_model'])
    segmenter.to(target_device)
    segmenter.eval()

    # Publish last, with `model` (the readiness flag) assigned at the very end.
    device = target_device
    seg_model = segmenter
    model = detector

    print("Models loaded successfully!")
|
|
|
|
def preprocess_frame(frame):
    """Turn an image array into a normalised channel-first float tensor.

    The frame is resized to 1024x1024, scaled to the [0, 1] range, normalised
    per channel with a fixed mean/std (the values commonly used for ImageNet
    statistics) and transposed from height-width-channel to
    channel-height-width order.

    Args:
        frame: Image array accepted by ``cv2.resize``; the three-element
            mean/std imply three channels in the last axis.

    Returns:
        A ``torch.float32`` tensor in channel-first layout (no batch axis).
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    resized = cv2.resize(frame, (1024, 1024))
    scaled = resized.astype(np.float32) / 255.0
    normalised = (scaled - channel_mean) / channel_std

    channel_first = normalised.transpose(2, 0, 1)
    return torch.from_numpy(channel_first).float()
|
|
|
|
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: Tensor whose last axis has four entries ``(cx, cy, w, h)``.

    Returns:
        Tensor of the same shape whose last axis is ``(x_min, y_min, x_max, y_max)``.
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)
|
|
|
|
def process_single_frame(frame_rgb, h, w):
    """Detect and segment instruments in one RGB frame and draw the results.

    Relies on the module globals ``model``, ``seg_model`` and ``device`` having
    been initialised by ``load_models()``.

    Args:
        frame_rgb: RGB image array to annotate (assumed H x W x 3 uint8 —
            TODO confirm against callers). It is not modified.
        h: Height in pixels used to scale boxes and resize masks; expected to
            match ``frame_rgb``.
        w: Width in pixels used to scale boxes and resize masks; expected to
            match ``frame_rgb``.

    Returns:
        ``frame_rgb`` itself when no detection clears the confidence
        threshold; otherwise an annotated copy with a translucent mask, a
        bounding box and a "name: score" caption per detection.
    """
    score_threshold = 0.3

    img_tensor = preprocess_frame(frame_rgb).unsqueeze(0).to(device)

    # All-False mask handed to NestedTensor alongside the image
    # (presumably "no padded pixels" — TODO confirm against util.misc).
    padding_mask = torch.zeros((1, 1024, 1024), dtype=torch.bool, device=device)
    samples = NestedTensor(img_tensor, padding_mask)

    # Pure inference: keep every forward pass out of autograd so the results
    # can be moved to NumPy and no graph is retained.
    with torch.no_grad():
        outputs, image_embeddings = model(samples)

        # Drop the last logit column (presumably DETR's "no object" class —
        # TODO confirm) and keep queries whose best class is confident enough.
        probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
        keep = probas.max(-1).values > score_threshold

        if not keep.any():
            return frame_rgb

        boxes = outputs['pred_boxes'][0, keep]
        scores = probas[keep].max(-1).values.cpu().numpy()
        labels = probas[keep].argmax(-1).cpu().numpy()

        # Boxes are scaled by (w, h, w, h) to obtain pixel corner coordinates.
        scale = torch.tensor([w, h, w, h], dtype=boxes.dtype, device=device)
        boxes_np = (box_cxcywh_to_xyxy(boxes) * scale).cpu().numpy()

        _low_res_masks, pred_masks, _ = seg_model(
            img_tensor, boxes, image_embeddings,
            sizes=(1024, 1024), add_noise=False
        )
        masks_np = pred_masks.cpu().numpy()

    result = frame_rgb.copy()
    # Every kept detection already satisfies score > score_threshold, so no
    # second score check is needed inside the loop.
    for box, label, mask_pred, score in zip(boxes_np, labels, masks_np, scores):
        label = int(label)
        color = COLORS[label % len(COLORS)]

        # Blend the class colour into the masked pixels at 40% opacity.
        # NOTE(review): assumes each mask is a 2-D float array that
        # cv2.resize accepts — confirm against SAMModel's output.
        mask_resized = cv2.resize(mask_pred, (w, h))
        mask_bool = mask_resized > 0.5
        overlay = result.copy()
        overlay[mask_bool] = color
        result = cv2.addWeighted(result, 0.6, overlay, 0.4, 0)

        x1, y1, x2, y2 = box.astype(int).tolist()
        cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)

        # The detector is built with num_classes=9 while only 8 names are
        # listed, so a predicted index can apparently fall outside the tuple;
        # fall back to a generic caption instead of raising IndexError.
        if 0 <= label < len(INSTRUMENT_CLASSES):
            class_name = INSTRUMENT_CLASSES[label]
        else:
            class_name = f"class_{label}"
        label_text = f"{class_name}: {score:.2f}"

        # Keep the caption on-screen for boxes that touch the top edge.
        text_y = max(y1 - 10, 12)
        cv2.putText(result, label_text, (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return result
|
|
|
|
@spaces.GPU
def predict_image(image):
    """Run instrument segmentation on a single image.

    Args:
        image: Input picture from the Gradio image component (a PIL image,
            since the component is declared with ``type="pil"``), or ``None``
            when nothing was provided.

    Returns:
        A PIL image with the detections drawn on it, or ``None`` when no
        input image was given.
    """
    # Nothing to segment: return before paying for weight download / loading.
    if image is None:
        return None

    # Check both networks so a previously interrupted load is retried.
    if model is None or seg_model is None:
        load_models()

    # Normalisation and colour overlays assume exactly three channels, so
    # collapse RGBA / greyscale / palette inputs to RGB first.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    frame_rgb = np.array(image)
    h, w = frame_rgb.shape[:2]

    result = process_single_frame(frame_rgb, h, w)

    return Image.fromarray(result)
|
|
|
|
@spaces.GPU(duration=300)
def predict_video(video_path, progress=gr.Progress()):
    """Segment every frame of a video and write the annotated result to a file.

    Args:
        video_path: Path of the input video, or ``None`` when nothing was
            provided.
        progress: Gradio progress tracker; the ``gr.Progress()`` default is
            how Gradio injects it.

    Returns:
        Path of a temporary ``.mp4`` file containing the annotated video, or
        ``None`` when no input was given.

    Raises:
        gr.Error: If the input video cannot be opened or the output writer
            cannot be created.
    """
    # Nothing to process: return before paying for weight download / loading.
    if video_path is None:
        return None

    # Check both networks so a previously interrupted load is retried.
    if model is None or seg_model is None:
        progress(0, desc="Loading models...")
        load_models()

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise gr.Error(f"Could not open video: {video_path}")

        # Keep the frame rate as a float: truncating 29.97 to 29 would change
        # the duration of the output. Containers that report no usable rate
        # (0, negative, NaN) fall back to an arbitrary 25 fps.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 25.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # mktemp() is deprecated and race-prone; reserve the file atomically
        # and keep it on disk so it can be returned to the caller.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise gr.Error("Could not create the output video writer")

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; the model pipeline works on RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Use the decoded frame's own size: header metadata can disagree
            # with the actual pixels, which would misplace boxes and masks.
            frame_h, frame_w = frame_rgb.shape[:2]
            result_rgb = process_single_frame(frame_rgb, frame_h, frame_w)

            result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
            # The writer was opened with the header size; match it so a frame
            # of a different size is not rejected by the writer.
            if (frame_w, frame_h) != (width, height):
                result_bgr = cv2.resize(result_bgr, (width, height))
            out.write(result_bgr)

            frame_count += 1
            if total_frames > 0:
                # Clamp: the reported frame count may be an underestimate.
                progress(min(frame_count / total_frames, 1.0),
                         desc=f"Processing frame {frame_count}/{total_frames}")
            else:
                # Frame count unknown: avoid dividing by zero.
                progress(0, desc=f"Processing frame {frame_count}")
    finally:
        # Release native handles even when a frame fails mid-way.
        cap.release()
        if out is not None:
            out.release()

    return output_path
|
|
|
|
| |
# Example inputs offered below each tab's controls.
_VIDEO_EXAMPLES = [
    "examples/surgical_demo.mp4",
    "examples/output.mp4",
]
_IMAGE_EXAMPLES = [
    "examples/example_2.png",
    "examples/example_3.png",
    "examples/example_4.png",
]

# Two-tab interface: video segmentation first, single-image segmentation second.
with gr.Blocks(title="Surgical-DeSAM", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔬 Surgical-DeSAM")
    gr.Markdown("Segment surgical instruments in images or videos using DeSAM architecture.")

    with gr.Tabs():
        # --- Video tab: input and button on the left, result on the right ---
        with gr.TabItem("🎬 Video Segmentation"):
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    video_btn = gr.Button("Segment Video", variant="primary")
                with gr.Column():
                    output_video = gr.Video(label="Segmentation Result")

            video_btn.click(
                fn=predict_video,
                inputs=input_video,
                outputs=output_video,
            )

            gr.Examples(
                examples=_VIDEO_EXAMPLES,
                inputs=input_video,
                label="Example Surgical Video",
            )

        # --- Image tab: same layout, PIL images in and out ---
        with gr.TabItem("🖼️ Image Segmentation"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Input Image")
                    image_btn = gr.Button("Segment Image", variant="primary")
                with gr.Column():
                    output_image = gr.Image(type="pil", label="Segmentation Result")

            image_btn.click(
                fn=predict_image,
                inputs=input_image,
                outputs=output_image,
            )

            gr.Examples(
                examples=_IMAGE_EXAMPLES,
                inputs=input_image,
                label="Example Surgical Images",
            )

    # Footer listing the instrument classes the model reports.
    gr.Markdown("""
    ## Detected Classes
    Bipolar Forceps | Prograsp Forceps | Large Needle Driver | Monopolar Curved Scissors
    Ultrasound Probe | Suction | Clip Applier | Stapler
    """)


# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|