Yang2001 commited on
Commit
53ad659
·
1 Parent(s): c1ef43c

refactor: use remote BRIA-RMBG-2.0 for background removal (no local GPU needed)

Browse files
Files changed (2) hide show
  1. app.py +54 -4
  2. autotune_cache.json +30 -0
app.py CHANGED
@@ -10,10 +10,13 @@ import numpy as np
10
  import base64
11
  import io
12
  import json
 
13
  from datetime import datetime
14
  from typing import *
15
  from PIL import Image
16
 
 
 
17
  import threading
18
  try:
19
  import nest_asyncio
@@ -138,6 +141,8 @@ def init_models():
138
  pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
139
 
140
  pipeline.cuda()
 
 
141
 
142
  print("[NAF] Pre-loading NAF upsampler model...")
143
  for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
@@ -156,6 +161,52 @@ def init_models():
156
  'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread(os.path.join(_base, 'assets/hdri/courtyard.exr'), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
157
  }
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  # ============================================================================
160
  # Utilities
161
  # ============================================================================
@@ -332,11 +383,9 @@ async def progress_sse(request: Request):
332
  return StreamingResponse(event_stream(), media_type="text/event-stream")
333
 
334
  @app.api()
335
- @spaces.GPU(duration=60)
336
  def preprocess(image: FileData) -> FileData:
337
- init_models()
338
  img = Image.open(image["path"])
339
- processed = pipeline.preprocess_image(img)
340
  out_path = os.path.join(TMP_DIR, f"preprocessed_{int(time.time()*1000)}.png")
341
  processed.save(out_path)
342
  return FileData(path=out_path)
@@ -369,7 +418,8 @@ def generate_3d(
369
  hr_resolution = int(resolution)
370
 
371
  img = Image.open(image["path"])
372
- image_preprocessed = pipeline.preprocess_image(img)
 
373
  temp_processed_path = os.path.join(TMP_DIR, f"temp_proc_{session_id[:8]}_{int(time.time()*1000)}.png")
374
  image_preprocessed.save(temp_processed_path)
375
 
 
10
  import base64
11
  import io
12
  import json
13
+ import tempfile
14
  from datetime import datetime
15
  from typing import *
16
  from PIL import Image
17
 
18
+ from gradio_client import Client as GradioClient, handle_file as gradio_handle_file
19
+
20
  import threading
21
  try:
22
  import nest_asyncio
 
141
  pipeline.image_cond_model_tex_1024 = build_image_cond_model(IMAGE_COND_CONFIGS["tex_1024"])
142
 
143
  pipeline.cuda()
144
+ pipeline.rembg_model = None # Use remote BRIA-RMBG-2.0 instead
145
+ pipeline.low_vram = False
146
 
147
  print("[NAF] Pre-loading NAF upsampler model...")
148
  for attr in ['image_cond_model_ss', 'image_cond_model_shape_512', 'image_cond_model_shape_1024', 'image_cond_model_tex_1024']:
 
161
  'courtyard': EnvMap(torch.tensor(cv2.cvtColor(cv2.imread(os.path.join(_base, 'assets/hdri/courtyard.exr'), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda')),
162
  }
163
 
164
+ # ============================================================================
165
+ # Remote Background Removal (same as Microsoft TRELLIS.2 official)
166
+ # ============================================================================
167
+
168
+ rmbg_client = GradioClient("briaai/BRIA-RMBG-2.0")
169
+
170
+ def remove_background_remote(input: Image.Image) -> Image.Image:
171
+ """Remove background using remote BRIA-RMBG-2.0 Space (no local GPU needed)."""
172
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
173
+ input = input.convert('RGB')
174
+ input.save(f.name)
175
+ output = rmbg_client.predict(gradio_handle_file(f.name), api_name="/image")[0][0]
176
+ result = Image.open(output)
177
+ os.unlink(f.name)
178
+ return result
179
+
180
+ def preprocess_image_remote(input: Image.Image, bg_color: tuple = (0, 0, 0)) -> Image.Image:
181
+ """Preprocess image using remote rembg (no GPU required)."""
182
+ # If has alpha channel, use it directly
183
+ has_alpha = False
184
+ if input.mode == 'RGBA':
185
+ alpha = np.array(input)[:, :, 3]
186
+ if not np.all(alpha == 255):
187
+ has_alpha = True
188
+ max_size = max(input.size)
189
+ scale = min(1, 1024 / max_size)
190
+ if scale < 1:
191
+ input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
192
+ if has_alpha:
193
+ output = input
194
+ else:
195
+ output = remove_background_remote(input)
196
+ output_np = np.array(output)
197
+ alpha = output_np[:, :, 3]
198
+ bbox = np.argwhere(alpha > 0.8 * 255)
199
+ bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
200
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
201
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
202
+ size = int(size * 1)
203
+ bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
204
+ output = output.crop(bbox)
205
+ output = np.array(output).astype(np.float32) / 255
206
+ output = output[:, :, :3] * output[:, :, 3:4] + np.array(bg_color) / 255 * (1 - output[:, :, 3:4])
207
+ output = Image.fromarray((output * 255).astype(np.uint8))
208
+ return output
209
+
210
  # ============================================================================
211
  # Utilities
212
  # ============================================================================
 
383
  return StreamingResponse(event_stream(), media_type="text/event-stream")
384
 
385
  @app.api()
 
386
  def preprocess(image: FileData) -> FileData:
 
387
  img = Image.open(image["path"])
388
+ processed = preprocess_image_remote(img)
389
  out_path = os.path.join(TMP_DIR, f"preprocessed_{int(time.time()*1000)}.png")
390
  processed.save(out_path)
391
  return FileData(path=out_path)
 
418
  hr_resolution = int(resolution)
419
 
420
  img = Image.open(image["path"])
421
+ # Image is already preprocessed by /preprocess endpoint, use directly
422
+ image_preprocessed = img
423
  temp_processed_path = os.path.join(TMP_DIR, f"temp_proc_{session_id[:8]}_{int(time.time()*1000)}.png")
424
  image_preprocessed.save(temp_processed_path)
425
 
autotune_cache.json CHANGED
@@ -24914,6 +24914,36 @@
24914
  "reg_inc_consumer": 0,
24915
  "maxnreg": null,
24916
  "pre_hook": null
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24917
  }
24918
  },
24919
  "flex_gemm.kernels.triton.spconv.sparse_submanifold_conv_bwd_implicit_gemm.sparse_submanifold_conv_bwd_input_implicit_gemm_kernel": {
 
24914
  "reg_inc_consumer": 0,
24915
  "maxnreg": null,
24916
  "pre_hook": null
24917
+ },
24918
+ "(21, 4194304, 6, 8, 'torch.float32', 'torch.uint32', 'torch.float32', 'torch.float32')": {
24919
+ "kwargs": {
24920
+ "BM": 16,
24921
+ "BK": 8
24922
+ },
24923
+ "num_warps": 2,
24924
+ "num_ctas": 1,
24925
+ "num_stages": 2,
24926
+ "num_buffers_warp_spec": 0,
24927
+ "num_consumer_groups": 0,
24928
+ "reg_dec_producer": 0,
24929
+ "reg_inc_consumer": 0,
24930
+ "maxnreg": null,
24931
+ "pre_hook": null
24932
+ },
24933
+ "(21, 7889449, 6, 8, 'torch.float32', 'torch.uint32', 'torch.float32', 'torch.float32')": {
24934
+ "kwargs": {
24935
+ "BM": 16,
24936
+ "BK": 8
24937
+ },
24938
+ "num_warps": 2,
24939
+ "num_ctas": 1,
24940
+ "num_stages": 2,
24941
+ "num_buffers_warp_spec": 0,
24942
+ "num_consumer_groups": 0,
24943
+ "reg_dec_producer": 0,
24944
+ "reg_inc_consumer": 0,
24945
+ "maxnreg": null,
24946
+ "pre_hook": null
24947
  }
24948
  },
24949
  "flex_gemm.kernels.triton.spconv.sparse_submanifold_conv_bwd_implicit_gemm.sparse_submanifold_conv_bwd_input_implicit_gemm_kernel": {