# Blurry — Gradio Space entry point (app.py).
# NOTE: low_cpu_mem_usage=False was reverted (commit f41fc92); the fix lives in requirements.
import os
import glob
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageFilter
import matplotlib
matplotlib.use('Agg')
import matplotlib.cm as cm
from torchvision import transforms
from transformers import (
AutoModelForImageSegmentation,
AutoImageProcessor,
AutoModelForDepthEstimation,
)
import gradio as gr
# ── Device ────────────────────────────────────────────────────────────────────
# Prefer GPU when available; every model below is moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ── HF login (RMBG-2.0 is gated) ─────────────────────────────────────────────
# Authenticate only when a token is provided; without one the app still starts
# and any access problem surfaces later, at model-download time.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    from huggingface_hub import login
    login(token=hf_token, add_to_git_credential=False)
# ── Patch birefnet.py for newer PyTorch compatibility ─────────────────────────
# Glob matching every cached copy of RMBG's remote modeling code.
_BIREFNET_GLOB = os.path.expanduser(
    "~/.cache/huggingface/modules/transformers_modules"
    "/briaai/RMBG*/**/birefnet.py"
)


def _find_birefnet_files():
    """Return all cached ``birefnet.py`` files downloaded for briaai/RMBG models."""
    return glob.glob(_BIREFNET_GLOB, recursive=True)


def _patch_birefnet():
    """Rewrite cached RMBG ``birefnet.py`` so it imports under newer PyTorch.

    The upstream remote code builds its stochastic-depth schedule from
    ``torch.linspace`` inside a list comprehension; replace both known
    variants of that expression with an equivalent pure-Python one.

    Side effects only (edits files in the HF modules cache); returns None.
    """
    candidates = _find_birefnet_files()
    if not candidates:
        # Nothing cached yet: fetch just the config (with remote code) so the
        # modeling file lands in the cache, then look again. Failures here are
        # non-fatal — real problems will surface when the model is loaded.
        from transformers import AutoConfig
        try:
            AutoConfig.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
        except Exception:
            pass
        candidates = _find_birefnet_files()
    old_variants = [
        '[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]',
        '[float(x) for x in torch.linspace(0, drop_path_rate, sum(depths))]',
    ]
    fixed = (
        '[drop_path_rate * i / (sum(depths) - 1) '
        'if sum(depths) > 1 else 0.0 '
        'for i in range(sum(depths))]'
    )
    # Patch EVERY cached copy (several RMBG revisions may coexist) and write a
    # file back only when a known-bad expression was actually replaced.
    # The original stopped after the first candidate and rewrote it even when
    # nothing had changed.
    for path in candidates:
        with open(path, 'r', encoding='utf-8') as f:
            content = f.read()
        patched = content
        for old in old_variants:
            patched = patched.replace(old, fixed)
        if patched != content:
            with open(path, 'w', encoding='utf-8') as f:
                f.write(patched)
# Apply the compatibility patch before the gated model's remote code is imported.
_patch_birefnet()
# ── Load models once at startup ───────────────────────────────────────────────
print(f"Loading models on {device}...")
# RMBG-2.0 background-removal model (remote code; gated repo — see HF login above).
seg_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0', trust_remote_code=True
).eval().to(device)
# Depth Anything V2 (small): monocular depth estimator plus its preprocessor.
depth_processor = AutoImageProcessor.from_pretrained(
    "depth-anything/Depth-Anything-V2-Small-hf"
)
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Small-hf"
).eval().to(device)
print("Models ready.")
# Preprocessing for RMBG-2.0: fixed 1024x1024 input with ImageNet normalization.
_seg_transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# ── Core pipeline ─────────────────────────────────────────────────────────────
def _segment_foreground(image_orig):
    """Part 2.1: RMBG-2.0 foreground mask at the original size (PIL mode 'L')."""
    inp = _seg_transform(image_orig).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = seg_model(inp)[-1].sigmoid().cpu()
    pred_pil = transforms.ToPILImage()(preds[0].squeeze())
    mask = pred_pil.resize(image_orig.size, Image.BILINEAR)
    # Threshold the soft matte at 0.5 (127/255) into a hard binary mask.
    return Image.fromarray(
        (np.array(mask) > 127).astype(np.uint8) * 255, mode='L'
    )


def _blur_background(image_orig, binary_mask, sigma):
    """Part 2.2: composite the sharp foreground over a Gaussian-blurred background."""
    image_np = np.array(image_orig).astype(np.float32)
    mask_float = np.array(binary_mask).astype(np.float32) / 255.0
    mask_3ch = np.stack([mask_float] * 3, axis=-1)
    blurred_bg = np.array(
        image_orig.filter(ImageFilter.GaussianBlur(radius=sigma))
    ).astype(np.float32)
    return Image.fromarray(
        (mask_3ch * image_np + (1.0 - mask_3ch) * blurred_bg).astype(np.uint8)
    )


def _estimate_depth(image_512):
    """Part 2.3: Depth Anything V2 on a 512x512 image.

    Returns ``(depth_norm, depth_vis)`` where ``depth_norm`` is a 512x512
    float array normalized to [0, 1] with 0 = nearest, 1 = farthest, and
    ``depth_vis`` is an inferno-colormapped PIL visualization of it.
    """
    inputs = depth_processor(images=image_512, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    depth_t = F.interpolate(
        outputs.predicted_depth.unsqueeze(1),
        size=(512, 512), mode='bicubic', align_corners=False
    ).squeeze().cpu().numpy()
    depth_t = depth_t.max() - depth_t  # invert: higher = farther
    depth_norm = (depth_t - depth_t.min()) / (depth_t.max() - depth_t.min() + 1e-8)
    depth_vis = Image.fromarray(
        (cm.inferno(depth_norm)[:, :, :3] * 255).astype(np.uint8)
    )
    return depth_norm, depth_vis


def _lens_blur(image_512, depth_norm, max_depth_blur):
    """Part 2.4: depth-of-field blur via blending a small stack of blur levels.

    The per-pixel blur radius grows quadratically with normalized depth; each
    pixel is linearly interpolated between the two pre-blurred images that
    bracket its radius.
    """
    blur_radius_map = (depth_norm ** 2) * max_depth_blur  # quadratic falloff
    blur_levels = np.linspace(0, max_depth_blur, 5).tolist()
    image_512_np = np.array(image_512).astype(np.float32)
    # Pre-compute the blur stack; level 0 is the sharp original.
    blurred_imgs = []
    for sigma in blur_levels:
        if sigma == 0:
            blurred_imgs.append(image_512_np.copy())
        else:
            blurred_imgs.append(
                np.array(
                    image_512.filter(ImageFilter.GaussianBlur(radius=sigma))
                ).astype(np.float32)
            )
    blurred_imgs = np.stack(blurred_imgs, axis=0)
    blur_levels_arr = np.array(blur_levels, dtype=np.float32)
    result = np.zeros_like(image_512_np)
    # Blend within each [lo, hi) band; bands partition the radius range, so
    # every pixel is written exactly once.
    for i in range(len(blur_levels) - 1):
        lo, hi = blur_levels_arr[i], blur_levels_arr[i + 1]
        band = ((blur_radius_map >= lo) & (blur_radius_map < hi)).astype(np.float32)
        if band.sum() == 0:
            continue  # skip empty bands — nothing to blend
        t = np.clip((blur_radius_map - lo) / (hi - lo + 1e-8), 0.0, 1.0)
        blended = (1.0 - t[:, :, None]) * blurred_imgs[i] + t[:, :, None] * blurred_imgs[i + 1]
        result += blended * band[:, :, None]
    # Pixels at/above the last level (e.g. the farthest point) get the max blur.
    result += blurred_imgs[-1] * (blur_radius_map >= blur_levels_arr[-1]).astype(np.float32)[:, :, None]
    return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))


def process(image: Image.Image, gaussian_sigma: float, max_depth_blur: float):
    """Run the full pipeline on one image.

    Args:
        image: uploaded PIL image (any mode; converted to RGB). May be None
            when the user clicks Run without uploading.
        gaussian_sigma: blur radius for the background blur (Part 2.2).
        max_depth_blur: maximum blur radius for the lens blur (Part 2.4).

    Returns:
        (binary_mask, composited, depth_vis, lens_blur) as PIL images for the
        four Gradio outputs.

    Raises:
        gr.Error: when no image was provided, so the UI shows a clean message
        instead of a raw traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image_orig = image.convert("RGB")
    # Depth model (and the lens-blur stage built on it) runs at 512x512.
    image_512 = image_orig.resize((512, 512), Image.LANCZOS)
    binary_mask = _segment_foreground(image_orig)
    composited = _blur_background(image_orig, binary_mask, gaussian_sigma)
    depth_norm, depth_vis = _estimate_depth(image_512)
    lens_blur = _lens_blur(image_512, depth_norm, max_depth_blur)
    return binary_mask, composited, depth_vis, lens_blur
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Components are registered in declaration order inside the Blocks context.
with gr.Blocks(title="Blurry") as demo:
    gr.Markdown(
        "# Blurry\n"
        "Upload an image to get foreground segmentation, background blur, "
        "depth estimation, and depth-of-field lens blur."
    )
    with gr.Row():
        # Left column: input image and the two blur-strength controls.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")
            gaussian_sigma = gr.Slider(
                minimum=1, maximum=50, value=15, step=1,
                label="Gaussian Blur σ (background blur, Part 2.2)"
            )
            max_depth_blur = gr.Slider(
                minimum=1, maximum=50, value=15, step=1,
                label="Max Depth Blur radius (lens blur, Part 2.4)"
            )
            run_btn = gr.Button("Run", variant="primary")
        # Right column: the four pipeline outputs in a 2x2 grid.
        with gr.Column(scale=2):
            with gr.Row():
                out_mask = gr.Image(label="2.1 — Segmentation Mask")
                out_blur = gr.Image(label="2.2 — Background Blur")
            with gr.Row():
                out_depth = gr.Image(label="2.3 — Depth Map")
                out_lens = gr.Image(label="2.4 — Lens Blur (DoF)")
    # Wire the button to the pipeline; outputs map 1:1 to parts 2.1–2.4.
    run_btn.click(
        fn=process,
        inputs=[input_image, gaussian_sigma, max_depth_blur],
        outputs=[out_mask, out_blur, out_depth, out_lens],
    )
demo.launch()