"""Gradio demo: document segmentation with a U2NETP model and a red overlay."""

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from seg import U2NETP


# Image processing utilities
def load_image(path: str) -> np.ndarray:
    """Load an image from *path* as an RGB array scaled to [0, 1].

    Raises:
        OSError: If OpenCV cannot read the file (``cv2.imread`` returned None).
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        raise OSError(f"Could not read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img / 255.0


def save_image(image: np.ndarray, path: str) -> None:
    """Save an RGB image with values in [0, 1] to *path*.

    Values outside [0, 1] are clipped before the 8-bit conversion so they
    saturate instead of wrapping around.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    img = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(path, img):
        raise OSError(f"Could not write image: {path}")


# Document Segmentation Model
class U2NETP_DocSeg(nn.Module):
    """Thin wrapper around ``U2NETP`` that returns only its first output."""

    def __init__(self, num_classes=1):
        super().__init__()
        self.u2netp = U2NETP(out_ch=num_classes)

    def forward(self, x):
        """Run the wrapped network and discard every output after the first."""
        mask, *_ = self.u2netp(x)
        return mask


# Initialize the document segmentation model
docseg = U2NETP_DocSeg(num_classes=1)

# Load pretrained weights
docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
# NOTE(review): torch.load unpickles the checkpoint; only load files from a
# trusted source (consider weights_only=True if the torch version supports it).
checkpoint = torch.load(docseg_weight_path, map_location=torch.device('cpu'))
docseg.load_state_dict(checkpoint["model_state_dict"])
docseg.eval()

# Spatial size the image is resized to before it is fed to the model.
_MODEL_INPUT_SIZE = (288, 288)


# Get segmentation mask
def get_mask(image: np.ndarray, confidence: float = 0.5) -> torch.Tensor:
    """Compute a segmentation mask for *image* at the image's own resolution.

    Args:
        image: Array indexed as (height, width, channels); it is converted to
            float32 and moved to channels-first layout before inference.
        confidence: Threshold applied to the model output; values strictly
            greater than it become 1.0, all others 0.0.

    Returns:
        A 2-D float tensor of shape ``image.shape[:2]``. Because the binary
        mask is resized bilinearly, values along region borders may lie
        strictly between 0 and 1.
    """
    org_shape = image.shape[:2]
    image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
    image_tensor = F.interpolate(image_tensor, size=_MODEL_INPUT_SIZE, mode='bilinear')

    with torch.inference_mode():  # faster than no_grad
        mask = docseg(image_tensor)

    mask = (mask > confidence).float()
    mask = F.interpolate(mask, size=org_shape, mode='bilinear')
    return mask[0, 0]  # keep tensor


def overlay_mask(image: np.ndarray, mask: torch.Tensor) -> np.ndarray:
    """Blend a red highlight into *image* wherever *mask* is positive.

    Args:
        image: Array indexed as (height, width, 3).
        mask: 2-D tensor with the same height and width as *image*.

    Returns:
        Array with the same layout as *image*: ``0.7 * image + 0.3 * overlay``,
        where ``overlay`` is pure red at positive mask pixels and the original
        image elsewhere.
    """
    image_weight, overlay_weight = 0.7, 0.3

    # (H, W, C) -> (1, C, H, W) so the image broadcasts against the mask.
    image_nchw = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
    red = torch.tensor([1.0, 0, 0]).view(1, 3, 1, 1)
    mask_nchw = mask.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W)

    overlay = torch.where(mask_nchw > 0, red, image_nchw)
    blended = image_weight * image_nchw + overlay_weight * overlay
    return blended[0].permute(1, 2, 0).cpu().numpy()


def segment_image(image):
    """Gradio handler: segment the input image and yield the overlay.

    Args:
        image: Array with values in [0, 255] as delivered by the input
            component, or None when the input has been cleared.

    Yields:
        The overlay image, or None when there is no input image.
    """
    if image is None:
        # The change event also fires when the input is cleared; there is
        # nothing to segment, so clear the output instead of raising.
        yield None
        return

    image = image.astype(np.float32) / 255.0  # Normalize to [0, 1]
    mask = get_mask(image, confidence=0.5)
    yield overlay_mask(image, mask)


with gr.Blocks() as demo:
    gr.Markdown("## Real-time Document Segmentation")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type="numpy")
        output_image = gr.Image(label="Segmentation Overlay", type="numpy")
    examples = gr.Examples(
        examples=[
            "./examples/sample.jpg",
            "./examples/manga.png",
            "./examples/invoice.png",
        ],
        inputs=input_image,
    )
    input_image.change(segment_image, inputs=input_image, outputs=output_image)

demo.launch()