import os
import glob
import warnings

warnings.filterwarnings('ignore')

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageFilter
import matplotlib
matplotlib.use('Agg')
import matplotlib.cm as cm
from torchvision import transforms
from transformers import (
    AutoModelForImageSegmentation,
    AutoImageProcessor,
    AutoModelForDepthEstimation,
)
import gradio as gr

# ── Device ───────────────────────────────────────────────────────────────────
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ── HF login (RMBG-2.0 is gated) ──────────────────────────────────────────────
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    from huggingface_hub import login
    login(token=hf_token, add_to_git_credential=False)

# ── Patch birefnet.py for newer PyTorch compatibility ─────────────────────────
def _patch_birefnet():
    candidates = glob.glob(
        os.path.expanduser(
            "~/.cache/huggingface/modules/transformers_modules"
            "/briaai/RMBG*/**/birefnet.py"
        ),
        recursive=True,
    )
    if not candidates:
        # Force transformers to download the remote code so there is a local
        # copy to patch.
        from transformers import AutoConfig
        try:
            AutoConfig.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
        except Exception:
            pass
        candidates = glob.glob(
            os.path.expanduser(
                "~/.cache/huggingface/modules/transformers_modules"
                "/briaai/RMBG*/**/birefnet.py"
            ),
            recursive=True,
        )
    for path in candidates:
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read()
        old_variants = [
            '[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]',
            '[float(x) for x in torch.linspace(0, drop_path_rate, sum(depths))]',
        ]
        fixed = (
            '[drop_path_rate * i / (sum(depths) - 1) '
            'if sum(depths) > 1 else 0.0 '
            'for i in range(sum(depths))]'
        )
        for old in old_variants:
            if old in content:
                content = content.replace(old, fixed)
                with open(path, 'w', encoding='utf-8') as f:
                    f.write(content)
                break

_patch_birefnet()
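
# Note on the patch above: RMBG-2.0's bundled birefnet.py builds its
# stochastic-depth (DropPath) schedule via torch.linspace(...) plus .item()
# calls, which can fail under newer PyTorch/transformers loading paths
# (e.g. meta-device initialization, where .item() on a tensor raises).
# The replacement is pure-Python arithmetic with identical values, since
# linspace(0, r, n)[i] == r * i / (n - 1). For example, r = 0.1 and
# sum(depths) = 4 gives [0.0, 0.0333..., 0.0666..., 0.1] either way.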

# ── Load models once at startup ────────────────────────────────────────────────
print(f"Loading models on {device}...")
seg_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0', trust_remote_code=True
).eval().to(device)
depth_processor = AutoImageProcessor.from_pretrained(
    "depth-anything/Depth-Anything-V2-Small-hf"
)
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Small-hf"
).eval().to(device)
print("Models ready.")

_seg_transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
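
# The transform above matches the RMBG-2.0 model card's example usage:
# 1024×1024 input normalized with the standard ImageNet mean/std. The depth
# model needs no such transform because depth_processor handles its own
# resizing and normalization.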

# ── Core pipeline ──────────────────────────────────────────────────────────────
def process(image: Image.Image, gaussian_sigma: float, max_depth_blur: float):
    image_orig = image.convert("RGB")
    image_512 = image_orig.resize((512, 512), Image.LANCZOS)

    # 2.1 — Foreground segmentation (RMBG-2.0)
    inp = _seg_transform(image_orig).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = seg_model(inp)[-1].sigmoid().cpu()
    pred_pil = transforms.ToPILImage()(preds[0].squeeze())
    mask = pred_pil.resize(image_orig.size, Image.BILINEAR)
    binary_mask = Image.fromarray(
        (np.array(mask) > 127).astype(np.uint8) * 255, mode='L'
    )

    # 2.2 — Gaussian background blur: composite = mask * sharp + (1 - mask) * blurred
    image_np = np.array(image_orig).astype(np.float32)
    mask_float = np.array(binary_mask).astype(np.float32) / 255.0
    mask_3ch = np.stack([mask_float] * 3, axis=-1)
    blurred_bg = np.array(
        image_orig.filter(ImageFilter.GaussianBlur(radius=gaussian_sigma))
    ).astype(np.float32)
    composited = Image.fromarray(
        (mask_3ch * image_np + (1.0 - mask_3ch) * blurred_bg).astype(np.uint8)
    )

    # 2.3 — Monocular depth estimation (Depth Anything V2 Small)
    inputs = depth_processor(images=image_512, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    depth_t = F.interpolate(
        outputs.predicted_depth.unsqueeze(1),
        size=(512, 512), mode='bicubic', align_corners=False
    ).squeeze().cpu().numpy()
    depth_t = depth_t.max() - depth_t  # model predicts higher = closer; invert so higher = farther
    depth_norm = (depth_t - depth_t.min()) / (depth_t.max() - depth_t.min() + 1e-8)
    depth_vis = Image.fromarray(
        (cm.inferno(depth_norm)[:, :, :3] * 255).astype(np.uint8)
    )

    # 2.4 — Depth-based variable lens blur. Each pixel's target blur radius
    # grows quadratically with normalized depth; the image is pre-blurred at
    # 5 discrete levels, and each pixel linearly interpolates between the two
    # levels that bracket its target radius.
    blur_radius_map = (depth_norm ** 2) * max_depth_blur
    blur_levels = np.linspace(0, max_depth_blur, 5).tolist()
    image_512_np = np.array(image_512).astype(np.float32)
    blurred_imgs = []
    for sigma in blur_levels:
        if sigma == 0:
            blurred_imgs.append(image_512_np.copy())
        else:
            blurred_imgs.append(
                np.array(image_512.filter(ImageFilter.GaussianBlur(radius=sigma))).astype(np.float32)
            )
    blurred_imgs = np.stack(blurred_imgs, axis=0)
    blur_levels_arr = np.array(blur_levels, dtype=np.float32)
    result = np.zeros_like(image_512_np)
    for i in range(len(blur_levels) - 1):
        lo, hi = blur_levels_arr[i], blur_levels_arr[i + 1]
        band = ((blur_radius_map >= lo) & (blur_radius_map < hi)).astype(np.float32)
        if band.sum() == 0:
            continue
        t = np.clip((blur_radius_map - lo) / (hi - lo + 1e-8), 0.0, 1.0)
        blended = (1.0 - t[:, :, None]) * blurred_imgs[i] + t[:, :, None] * blurred_imgs[i + 1]
        result += blended * band[:, :, None]
    # Pixels at exactly the maximum radius fall outside every half-open band.
    result += blurred_imgs[-1] * (blur_radius_map >= blur_levels_arr[-1]).astype(np.float32)[:, :, None]
    lens_blur = Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
    return binary_mask, composited, depth_vis, lens_blur
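
# Example (hypothetical, for headless testing without the UI — assumes a
# local file named "photo.jpg"):
#
#   img = Image.open("photo.jpg")
#   mask, composited, depth_vis, lens = process(img, gaussian_sigma=15, max_depth_blur=15)
#   lens.save("lens_blur.png")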

# ── Gradio UI ──────────────────────────────────────────────────────────────────
with gr.Blocks(title="Blurry") as demo:
    gr.Markdown(
        "# Blurry\n"
        "Upload an image to get foreground segmentation, background blur, "
        "depth estimation, and depth-of-field lens blur."
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            gaussian_sigma = gr.Slider(
                minimum=1, maximum=50, value=15, step=1,
                label="Gaussian Blur σ (background blur, Part 2.2)"
            )
            max_depth_blur = gr.Slider(
                minimum=1, maximum=50, value=15, step=1,
                label="Max Depth Blur radius (lens blur, Part 2.4)"
            )
            run_btn = gr.Button("Run", variant="primary")
        with gr.Column(scale=2):
            with gr.Row():
                out_mask = gr.Image(label="2.1 — Segmentation Mask")
                out_blur = gr.Image(label="2.2 — Background Blur")
            with gr.Row():
                out_depth = gr.Image(label="2.3 — Depth Map")
                out_lens = gr.Image(label="2.4 — Lens Blur (DoF)")
    run_btn.click(
        fn=process,
        inputs=[input_image, gaussian_sigma, max_depth_blur],
        outputs=[out_mask, out_blur, out_depth, out_lens],
    )

demo.launch()
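# Note: demo.launch() blocks here. On Hugging Face Spaces the defaults bind to
# the expected host/port; for local testing, demo.launch(share=True) would
# additionally create a temporary public URL (optional, standard Gradio flag).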