"""Blurry — Gradio demo: foreground segmentation (RMBG-2.0), Gaussian background
blur, monocular depth estimation (Depth Anything V2 Small), and depth-based
variable lens blur (depth of field).
"""

import glob
import os
import warnings

# Silence warnings before importing the heavyweight libraries, so their
# import-time warnings are suppressed too (matches the original ordering).
warnings.filterwarnings('ignore')

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageFilter

import matplotlib

# Headless backend must be selected before pyplot/cm are imported.
matplotlib.use('Agg')
import matplotlib.cm as cm

from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    AutoImageProcessor,
    AutoModelForDepthEstimation,
)
import gradio as gr

# ── Device ────────────────────────────────────────────────────────────────────
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ── HF login (RMBG-2.0 is gated) ──────────────────────────────────────────────
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    from huggingface_hub import login

    login(token=hf_token, add_to_git_credential=False)


# ── Patch birefnet.py for newer PyTorch compatibility ─────────────────────────
def _find_birefnet_sources():
    """Return cached birefnet.py paths for briaai/RMBG* remote-code modules."""
    return glob.glob(
        os.path.expanduser(
            "~/.cache/huggingface/modules/transformers_modules"
            "/briaai/RMBG*/**/birefnet.py"
        ),
        recursive=True,
    )


def _patch_birefnet():
    """Rewrite the drop-path-rate schedule in the cached RMBG remote code.

    Newer PyTorch versions choke on ``torch.linspace(...).item()`` inside the
    BiReFNet source that transformers downloads for ``trust_remote_code``; this
    replaces it with an equivalent pure-Python expression. Best effort: any
    failure here is non-fatal (the model load will surface real problems).
    """
    candidates = _find_birefnet_sources()
    if not candidates:
        # Trigger the remote-code download so the cached module exists,
        # then look again.
        from transformers import AutoConfig

        try:
            AutoConfig.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
        except Exception:
            pass  # best effort — model load below will report real errors
        candidates = _find_birefnet_sources()

    old_variants = [
        '[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]',
        '[float(x) for x in torch.linspace(0, drop_path_rate, sum(depths))]',
    ]
    fixed = (
        '[drop_path_rate * i / (sum(depths) - 1) '
        'if sum(depths) > 1 else 0.0 '
        'for i in range(sum(depths))]'
    )
    for path in candidates:
        try:
            with open(path, 'r', encoding='utf-8') as f:
                content = f.read()
            for old in old_variants:
                if old in content:
                    content = content.replace(old, fixed)
                    with open(path, 'w', encoding='utf-8') as f:
                        f.write(content)
                    break
        except OSError:
            # Unreadable/unwritable cache file must not kill the app.
            continue


_patch_birefnet()

# ── Load models once at startup ───────────────────────────────────────────────
print(f"Loading models on {device}...")
seg_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0', trust_remote_code=True
).eval().to(device)
depth_processor = AutoImageProcessor.from_pretrained(
    "depth-anything/Depth-Anything-V2-Small-hf"
)
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Small-hf"
).eval().to(device)
print("Models ready.")

# Preprocessing expected by RMBG-2.0 (ImageNet mean/std at 1024×1024).
_seg_transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# ── Core pipeline ─────────────────────────────────────────────────────────────
def _segment_foreground(image: Image.Image) -> Image.Image:
    """2.1 — Run RMBG-2.0 and return a binarized 'L' mask at *image*'s size."""
    inp = _seg_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = seg_model(inp)[-1].sigmoid().cpu()
    pred_pil = transforms.ToPILImage()(preds[0].squeeze())
    mask = pred_pil.resize(image.size, Image.BILINEAR)
    return Image.fromarray(
        (np.array(mask) > 127).astype(np.uint8) * 255, mode='L'
    )


def _composite_blurred_background(
    image: Image.Image, binary_mask: Image.Image, sigma: float
) -> Image.Image:
    """2.2 — Keep masked foreground sharp, Gaussian-blur the background."""
    image_np = np.array(image).astype(np.float32)
    # (H, W, 1) alpha broadcasts over RGB channels.
    alpha = (np.array(binary_mask).astype(np.float32) / 255.0)[:, :, None]
    blurred_bg = np.array(
        image.filter(ImageFilter.GaussianBlur(radius=sigma))
    ).astype(np.float32)
    return Image.fromarray(
        (alpha * image_np + (1.0 - alpha) * blurred_bg).astype(np.uint8)
    )


def _estimate_depth(image_512: Image.Image):
    """2.3 — Depth Anything V2 inference on a 512×512 image.

    Returns ``(depth_norm, depth_vis)``: a normalized [0, 1] float map where
    higher means farther, and an inferno-colormapped PIL visualization.
    """
    inputs = depth_processor(images=image_512, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    depth = F.interpolate(
        outputs.predicted_depth.unsqueeze(1),
        size=(512, 512), mode='bicubic', align_corners=False,
    ).squeeze().cpu().numpy()
    depth = depth.max() - depth  # invert: higher = farther
    depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    depth_vis = Image.fromarray(
        (cm.inferno(depth_norm)[:, :, :3] * 255).astype(np.uint8)
    )
    return depth_norm, depth_vis


def _depth_lens_blur(
    image_512: Image.Image, depth_norm: np.ndarray, max_blur: float
) -> Image.Image:
    """2.4 — Depth-of-field: blend 5 blur levels per pixel by depth.

    Each pixel's target blur radius is ``depth_norm**2 * max_blur`` (quadratic
    falloff); the result linearly interpolates between the two nearest
    precomputed blur levels.
    """
    blur_map = (depth_norm ** 2) * max_blur
    levels = np.linspace(0, max_blur, 5).astype(np.float32)

    base = np.array(image_512).astype(np.float32)
    stack = [base.copy()]  # level 0 is exactly the sharp image
    for sigma in levels[1:]:
        stack.append(
            np.array(
                image_512.filter(ImageFilter.GaussianBlur(radius=float(sigma)))
            ).astype(np.float32)
        )
    stack = np.stack(stack, axis=0)

    result = np.zeros_like(base)
    for i in range(len(levels) - 1):
        lo, hi = levels[i], levels[i + 1]
        band = ((blur_map >= lo) & (blur_map < hi)).astype(np.float32)
        if band.sum() == 0:
            continue
        t = np.clip((blur_map - lo) / (hi - lo + 1e-8), 0.0, 1.0)
        blended = (1.0 - t[:, :, None]) * stack[i] + t[:, :, None] * stack[i + 1]
        result += blended * band[:, :, None]
    # Pixels at (or numerically above) the top level get the strongest blur.
    result += stack[-1] * (blur_map >= levels[-1]).astype(np.float32)[:, :, None]
    return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))


def process(image: Image.Image, gaussian_sigma: float, max_depth_blur: float):
    """Run the full pipeline on one uploaded image.

    Args:
        image: input PIL image (any mode; converted to RGB).
        gaussian_sigma: radius for the uniform background blur (2.2).
        max_depth_blur: maximum blur radius for the depth-of-field pass (2.4).

    Returns:
        ``(binary_mask, composited, depth_vis, lens_blur)`` PIL images. The
        mask and composite are at the original resolution; the depth map and
        lens-blur result are 512×512.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image_orig = image.convert("RGB")
    image_512 = image_orig.resize((512, 512), Image.LANCZOS)

    binary_mask = _segment_foreground(image_orig)
    composited = _composite_blurred_background(
        image_orig, binary_mask, gaussian_sigma
    )
    depth_norm, depth_vis = _estimate_depth(image_512)
    lens_blur = _depth_lens_blur(image_512, depth_norm, max_depth_blur)
    return binary_mask, composited, depth_vis, lens_blur


# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(title="Blurry") as demo:
    gr.Markdown(
        "# Blurry\n"
        "Upload an image to get foreground segmentation, background blur, "
        "depth estimation, and depth-of-field lens blur."
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            gaussian_sigma = gr.Slider(
                minimum=1, maximum=50, value=15, step=1,
                label="Gaussian Blur σ (background blur, Part 2.2)"
            )
            max_depth_blur = gr.Slider(
                minimum=1, maximum=50, value=15, step=1,
                label="Max Depth Blur radius (lens blur, Part 2.4)"
            )
            run_btn = gr.Button("Run", variant="primary")
        with gr.Column(scale=2):
            with gr.Row():
                out_mask = gr.Image(label="2.1 — Segmentation Mask")
                out_blur = gr.Image(label="2.2 — Background Blur")
            with gr.Row():
                out_depth = gr.Image(label="2.3 — Depth Map")
                out_lens = gr.Image(label="2.4 — Lens Blur (DoF)")
    run_btn.click(
        fn=process,
        inputs=[input_image, gaussian_sigma, max_depth_blur],
        outputs=[out_mask, out_blur, out_depth, out_lens],
    )

if __name__ == "__main__":
    demo.launch()