Yang2001 commited on
Commit
8cb4c62
·
1 Parent(s): 551545a

Use local RMBG model, fix device placement, improve tqdm progress tracking

Browse files
app.py CHANGED
@@ -10,13 +10,10 @@ import numpy as np
10
  import base64
11
  import io
12
  import json
13
- import tempfile
14
  from datetime import datetime
15
  from typing import *
16
  from PIL import Image
17
 
18
- from gradio_client import Client as GradioClient, handle_file as gradio_handle_file
19
-
20
  import threading
21
  try:
22
  import nest_asyncio
@@ -140,7 +137,6 @@ def init_models():
140
  pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
141
  pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
142
 
143
- pipeline.rembg_model = None # Use remote BRIA-RMBG-2.0 instead
144
  pipeline.low_vram = False
145
  pipeline.cuda()
146
 
@@ -167,52 +163,6 @@ def init_models():
167
  'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread(os.path.join(_base, 'assets/hdri/courtyard.exr'), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
168
  }
169
 
170
- # ============================================================================
171
- # Remote Background Removal (same as Microsoft TRELLIS.2 official)
172
- # ============================================================================
173
-
174
- rmbg_client = GradioClient("briaai/BRIA-RMBG-2.0")
175
-
176
- def remove_background_remote(input: Image.Image) -> Image.Image:
177
- """Remove background using remote BRIA-RMBG-2.0 Space (no local GPU needed)."""
178
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
179
- input = input.convert('RGB')
180
- input.save(f.name)
181
- output = rmbg_client.predict(gradio_handle_file(f.name), api_name="/image")[0][0]
182
- result = Image.open(output)
183
- os.unlink(f.name)
184
- return result
185
-
186
- def preprocess_image_remote(input: Image.Image, bg_color: tuple = (0, 0, 0)) -> Image.Image:
187
- """Preprocess image using remote rembg (no GPU required)."""
188
- # If has alpha channel, use it directly
189
- has_alpha = False
190
- if input.mode == 'RGBA':
191
- alpha = np.array(input)[:, :, 3]
192
- if not np.all(alpha == 255):
193
- has_alpha = True
194
- max_size = max(input.size)
195
- scale = min(1, 1024 / max_size)
196
- if scale < 1:
197
- input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
198
- if has_alpha:
199
- output = input
200
- else:
201
- output = remove_background_remote(input)
202
- output_np = np.array(output)
203
- alpha = output_np[:, :, 3]
204
- bbox = np.argwhere(alpha > 0.8 * 255)
205
- bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
206
- center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
207
- size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
208
- size = int(size * 1)
209
- bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
210
- output = output.crop(bbox)
211
- output = np.array(output).astype(np.float32) / 255
212
- output = output[:, :, :3] * output[:, :, 3:4] + np.array(bg_color) / 255 * (1 - output[:, :, 3:4])
213
- output = Image.fromarray((output * 255).astype(np.uint8))
214
- return output
215
-
216
  # ============================================================================
217
  # Utilities
218
  # ============================================================================
@@ -395,9 +345,11 @@ async def progress_sse(request: Request):
395
  return StreamingResponse(event_stream(), media_type="text/event-stream")
396
 
397
  @app.api()
 
398
  def preprocess(image: FileData) -> FileData:
 
399
  img = Image.open(image["path"])
400
- processed = preprocess_image_remote(img)
401
  out_path = os.path.join(TMP_DIR, f"preprocessed_{int(time.time()*1000)}.png")
402
  processed.save(out_path)
403
  return FileData(path=out_path)
@@ -540,4 +492,4 @@ if __name__ == "__main__":
540
  # Pre-initialize models before launching the server
541
  init_models()
542
 
543
- app.launch(show_error=True, share=True)
 
10
  import base64
11
  import io
12
  import json
 
13
  from datetime import datetime
14
  from typing import *
15
  from PIL import Image
16
 
 
 
17
  import threading
18
  try:
19
  import nest_asyncio
 
137
  pipeline.image_cond_model_shape_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["shape_1024"])
138
  pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
139
 
 
140
  pipeline.low_vram = False
141
  pipeline.cuda()
142
 
 
163
  'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread(os.path.join(_base, 'assets/hdri/courtyard.exr'), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
164
  }
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  # ============================================================================
167
  # Utilities
168
  # ============================================================================
 
345
  return StreamingResponse(event_stream(), media_type="text/event-stream")
346
 
347
  @app.api()
348
+ @spaces.GPU(duration=30)
349
  def preprocess(image: FileData) -> FileData:
350
+ init_models()
351
  img = Image.open(image["path"])
352
+ processed = pipeline.preprocess_image(img)
353
  out_path = os.path.join(TMP_DIR, f"preprocessed_{int(time.time()*1000)}.png")
354
  processed.save(out_path)
355
  return FileData(path=out_path)
 
492
  # Pre-initialize models before launching the server
493
  init_models()
494
 
495
+ app.launch(show_error=True, share=True)
trellis2/pipelines/pixal3d_image_to_3d.py CHANGED
@@ -120,7 +120,7 @@ class Pixal3DImageTo3DPipeline(Pipeline):
120
  pipeline.image_cond_model_shape_1024 = None
121
  pipeline.image_cond_model_tex_1024 = None
122
 
123
- pipeline.rembg_model = None # Skip local RMBG loading; use remote client instead
124
 
125
  pipeline.low_vram = args.get('low_vram', True)
126
  pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
 
120
  pipeline.image_cond_model_shape_1024 = None
121
  pipeline.image_cond_model_tex_1024 = None
122
 
123
+ pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
124
 
125
  pipeline.low_vram = args.get('low_vram', True)
126
  pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')