userhugginggit commited on
Commit
7fbde21
verified
1 Parent(s): 77659bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -67
app.py CHANGED
@@ -1,41 +1,70 @@
1
- import gradio as gr
2
- from loadimg import load_img
3
- #import spaces
4
- from transformers import AutoModelForImageSegmentation
5
  import torch
6
- from torchvision import transforms
7
- from typing import Union, Tuple
8
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
10
  birefnet = AutoModelForImageSegmentation.from_pretrained(
11
- "merve/BiRefNet", low_cpu_mem_usage=False, trust_remote_code=True, torch_dtype=torch.float32, device_map=None
12
- )
13
- birefnet = birefnet.eval()
14
- #birefnet.to("cuda")
15
-
16
- transform_image = transforms.Compose(
17
- [
18
- transforms.Resize((1024, 1024)),
19
- transforms.ToTensor(),
20
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
21
- ]
22
- )
23
 
24
- def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
25
- """
26
- Remove the background from an image and return both the transparent version and the original.
27
 
28
- This function performs background removal using a BiRefNet segmentation model. It is intended for use
29
- with image input (either uploaded or from a URL). The function returns a transparent PNG version of the image
30
- with the background removed, along with the original RGB version for comparison.
 
 
 
31
 
32
- Args:
33
- image (PIL.Image or str): The input image, either as a PIL object or a filepath/URL string.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- Returns:
36
- tuple:
37
- - origin (PIL.Image): The original RGB image, unchanged.
38
- - processed_image (PIL.Image): The input image with the background removed and transparency applied.
39
  """
40
  im = load_img(image, output_type="pil")
41
  im = im.convert("RGB")
@@ -43,64 +72,50 @@ def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
43
  processed_image = process(im)
44
  return (origin, processed_image)
45
 
46
- #@spaces.GPU
47
- def process(image: Image.Image) -> Image.Image:
48
- """
49
- Apply BiRefNet-based image segmentation to remove the background.
50
-
51
- This function preprocesses the input image, runs it through a BiRefNet segmentation model to obtain a mask,
52
- and applies the mask as an alpha (transparency) channel to the original image.
53
-
54
- Args:
55
- image (PIL.Image): The input RGB image.
56
-
57
- Returns:
58
- PIL.Image: The image with the background removed, using the segmentation mask as transparency.
59
- """
60
- image_size = image.size
61
- input_images = transform_image(image).unsqueeze(0)
62
- with torch.inference_mode():
63
- preds = birefnet(input_images)[-1].sigmoid().detach().cpu()
64
- pred = preds[0].squeeze()
65
- pred_pil = transforms.ToPILImage()(pred)
66
- mask = pred_pil.resize(image_size)
67
- image.putalpha(mask)
68
- return image
69
-
70
  def process_file(f: str) -> str:
71
  """
72
- Load an image file from disk, remove the background, and save the output as a transparent PNG.
73
-
74
- Args:
75
- f (str): Filepath of the image to process.
76
-
77
- Returns:
78
- str: Path to the saved PNG image with background removed.
79
  """
80
  name_path = f.rsplit(".", 1)[0] + ".png"
81
  im = load_img(f, output_type="pil")
82
  im = im.convert("RGB")
83
  transparent = process(im)
84
- transparent.save(name_path)
85
  return name_path
86
 
 
 
 
 
87
  slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
88
  slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
 
89
  image_upload = gr.Image(label="Upload an image")
90
  image_file_upload = gr.Image(label="Upload an image", type="filepath")
91
  url_input = gr.Textbox(label="Paste an image URL")
92
  output_file = gr.File(label="Output PNG File")
93
 
94
- # Example images
95
- chameleon = load_img("butterfly.jpg", output_type="pil")
96
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
97
 
98
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
 
 
 
 
 
 
 
 
 
99
  tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
100
- tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
101
 
102
  demo = gr.TabbedInterface(
103
- [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
 
 
104
  )
105
 
106
  if __name__ == "__main__":
 
1
+ import os
 
 
 
2
  import torch
 
 
3
  from PIL import Image
4
+ from typing import Union, Tuple
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation
7
+ import gradio as gr
8
+ from loadimg import load_img
9
+
10
+ # =========================================================================
11
+ # CONFIGURACI脫N DE DISPOSITIVO (CPU)
12
+ # =========================================================================
13
+ DEVICE = "cpu"
14
+
15
+ print(f"--- Cargando BiRefNet en {DEVICE.upper()} ---")
16
 
17
+ # Cargamos el modelo directamente del Hub de Hugging Face
18
  birefnet = AutoModelForImageSegmentation.from_pretrained(
19
+ "merve/BiRefNet",
20
+ trust_remote_code=True,
21
+ torch_dtype=torch.float32
22
+ ).to(DEVICE)
 
 
 
 
 
 
 
 
23
 
24
+ birefnet.eval()
25
+ print("Modelo cargado correctamente en CPU.")
 
26
 
27
+ # Transformaciones necesarias para el modelo
28
+ transform_image = transforms.Compose([
29
+ transforms.Resize((1024, 1024)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
32
+ ])
33
 
34
+ # =========================================================================
35
+ # FUNCIONES DE PROCESAMIENTO
36
+ # =========================================================================
37
+
38
+ def process(image: Image.Image) -> Image.Image:
39
+ """
40
+ Aplica BiRefNet para remover el fondo de la imagen usando CPU.
41
+ """
42
+ image_size = image.size
43
+
44
+ # 1. Preparar el tensor para la red
45
+ input_tensor = transform_image(image).unsqueeze(0).to(DEVICE)
46
+
47
+ # 2. Inferencia (Paso por la red neuronal sin almacenar gradientes)
48
+ with torch.no_grad():
49
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
50
+
51
+ # 3. Crear la m谩scara Alfa
52
+ mask = preds[0].squeeze()
53
+ mask_pil = transforms.ToPILImage()(mask)
54
+
55
+ # 4. Ajustar m谩scara al tama帽o original con alta calidad (LANCZOS)
56
+ mask_final = mask_pil.resize(image_size, Image.LANCZOS)
57
+
58
+ # 5. Aplicar transparencia a la imagen original
59
+ output_image = image.copy()
60
+ output_image.putalpha(mask_final)
61
+
62
+ return output_image
63
 
64
+ def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
65
+ """
66
+ Funci贸n para las pesta帽as de Gradio (Subida de Imagen y URL).
67
+ Devuelve la imagen original y la versi贸n procesada para el ImageSlider.
68
  """
69
  im = load_img(image, output_type="pil")
70
  im = im.convert("RGB")
 
72
  processed_image = process(im)
73
  return (origin, processed_image)
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def process_file(f: str) -> str:
76
  """
77
+ Funci贸n para la pesta帽a de archivos. Guarda y devuelve la ruta del PNG.
 
 
 
 
 
 
78
  """
79
  name_path = f.rsplit(".", 1)[0] + ".png"
80
  im = load_img(f, output_type="pil")
81
  im = im.convert("RGB")
82
  transparent = process(im)
83
+ transparent.save(name_path, "PNG")
84
  return name_path
85
 
86
+ # =========================================================================
87
+ # INTERFAZ GRADIO
88
+ # =========================================================================
89
+
90
  slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
91
  slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
92
+
93
  image_upload = gr.Image(label="Upload an image")
94
  image_file_upload = gr.Image(label="Upload an image", type="filepath")
95
  url_input = gr.Textbox(label="Paste an image URL")
96
  output_file = gr.File(label="Output PNG File")
97
 
98
+ # Ejemplos por defecto
99
+ example_image_path = "butterfly.jpg"
100
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
101
 
102
+ # Carga segura de la imagen de ejemplo local para evitar crasheos si no se ha subido a煤n
103
+ try:
104
+ chameleon = load_img(example_image_path, output_type="pil")
105
+ examples_img = [chameleon]
106
+ examples_file = [example_image_path]
107
+ except Exception:
108
+ examples_img = None
109
+ examples_file = None
110
+
111
+ tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=examples_img, api_name="image")
112
  tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
113
+ tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=examples_file, api_name="png")
114
 
115
  demo = gr.TabbedInterface(
116
+ [tab1, tab2, tab3],
117
+ ["Image Upload", "URL Input", "File Output"],
118
+ title="Background Removal Tool (CPU Edition)"
119
  )
120
 
121
  if __name__ == "__main__":