saliacoel commited on
Commit
71c89c1
·
verified ·
1 Parent(s): d828181

Update Inspyrenet_Rembg2.py

Browse files
Files changed (1) hide show
  1. Inspyrenet_Rembg2.py +285 -508
Inspyrenet_Rembg2.py CHANGED
@@ -1,20 +1,3 @@
1
- # Inspyrenet_Rembg2.py
2
- # - Keeps InspyrenetRembg2 + InspyrenetRembg3
3
- # - REMOVES InspyrenetRembg4 (do not include)
4
- # - ADDS:
5
- # Load_Inspyrenet_Global
6
- # Remove_Inspyrenet_Gobal
7
- # Run_InspyrenetRembg_Global
8
- #
9
- # Design:
10
- # - One process-wide Remover singleton per (mode, jit, ckpt)
11
- # - Construct Remover on CPU (safe). You explicitly move it to GPU via Load node.
12
- # - Run node:
13
- # * ensures it is on desired device (auto/cuda/cpu)
14
- # * if OOM: evicts VRAM (smallest-first or comfy_default, etc.) and retries
15
- # * if still OOM: optionally falls back to CPU
16
- # * NEVER crashes on OOM: returns original image (pass-through) as last resort
17
-
18
  from __future__ import annotations
19
 
20
  from PIL import Image
@@ -22,15 +5,15 @@ import os
22
  import urllib.request
23
  import gc
24
  import threading
25
- from typing import Dict, Tuple, Optional, Any
26
- from contextlib import nullcontext
27
 
28
  import torch
29
  import numpy as np
30
  from transparent_background import Remover
31
  from tqdm import tqdm
32
 
33
- # Optional: ComfyUI memory manager (present when running inside ComfyUI)
 
34
  try:
35
  import comfy.model_management as comfy_mm
36
  except Exception:
@@ -42,11 +25,6 @@ CKPT_URL = "https://huggingface.co/saliacoel/x/resolve/main/ckpt_base.pth"
42
 
43
 
44
  def _ensure_ckpt_base():
45
- """
46
- 1) Check /root/.transparent-background/ckpt_base.pth
47
- - if exists: do nothing
48
- - else: download from CKPT_URL
49
- """
50
  try:
51
  if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
52
  return
@@ -83,7 +61,6 @@ def _ensure_ckpt_base():
83
  f.write(chunk)
84
 
85
  os.replace(tmp_path, CKPT_PATH)
86
-
87
  finally:
88
  if os.path.isfile(tmp_path):
89
  try:
@@ -92,10 +69,7 @@ def _ensure_ckpt_base():
92
  pass
93
 
94
 
95
- # -----------------------------------------------------------------------------
96
- # Conversions
97
- # -----------------------------------------------------------------------------
98
-
99
  def tensor2pil(image: torch.Tensor) -> Image.Image:
100
  arr = image.detach().cpu().numpy()
101
  if arr.ndim == 4 and arr.shape[0] == 1:
@@ -104,18 +78,12 @@ def tensor2pil(image: torch.Tensor) -> Image.Image:
104
  return Image.fromarray(arr)
105
 
106
 
 
107
  def pil2tensor(image: Image.Image) -> torch.Tensor:
108
  return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
109
 
110
 
111
  def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
112
- """
113
- If input is RGBA:
114
- - alpha composite over WHITE background
115
- - convert to RGB (drop alpha)
116
- If input is RGB:
117
- - carry on
118
- """
119
  if pil_img.mode == "RGBA":
120
  bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
121
  composited = Image.alpha_composite(bg, pil_img)
@@ -127,19 +95,28 @@ def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
127
  return pil_img
128
 
129
 
130
- def _force_rgba_passthrough(pil_img: Image.Image) -> Image.Image:
131
  """
132
- Make sure we can return a sane IMAGE if everything fails.
133
- This is a PASS-THROUGH fallback (NOT white).
134
  """
135
- if pil_img.mode == "RGBA":
136
- return pil_img
137
- return pil_img.convert("RGBA")
 
138
 
139
 
140
- # -----------------------------------------------------------------------------
141
- # OOM + CUDA cleanup
142
- # -----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
143
 
144
  def _is_oom_error(e: BaseException) -> bool:
145
  oom_cuda_cls = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
@@ -194,88 +171,72 @@ def _comfy_soft_empty_cache() -> None:
194
  pass
195
 
196
 
197
- def _bytes_from_mb(mb: int) -> int:
198
- return int(mb) * 1024 * 1024
199
-
200
-
201
- def _get_free_vram_bytes_best_effort() -> Optional[int]:
202
  """
203
- Prefer ComfyUI's free memory estimate (it includes torch-reserved-but-free blocks),
204
- otherwise fall back to torch.cuda.mem_get_info.
205
  """
206
- if comfy_mm is not None and hasattr(comfy_mm, "get_free_memory"):
207
  try:
208
- return int(comfy_mm.get_free_memory())
 
 
 
209
  except Exception:
210
  pass
211
 
212
  if torch.cuda.is_available():
213
- try:
214
- free_b, _total_b = torch.cuda.mem_get_info()
215
- return int(free_b)
216
- except Exception:
217
- pass
218
-
219
- return None
220
 
221
 
222
- # -----------------------------------------------------------------------------
223
- # ComfyUI model eviction (smallest-first, comfy_default, unload_all)
224
- # -----------------------------------------------------------------------------
225
-
226
- _GLOBAL_VRAM_EVICT_LOCK = threading.Lock()
227
-
228
-
229
- def _comfy_unload_all_models() -> None:
230
- if comfy_mm is None:
231
- return
232
- if hasattr(comfy_mm, "unload_all_models"):
233
  try:
234
- comfy_mm.unload_all_models()
 
235
  except Exception:
236
  pass
237
- _comfy_soft_empty_cache()
238
- _cuda_soft_cleanup()
239
 
240
 
241
- def _comfy_free_memory_to_target(target_free_bytes: int) -> None:
242
- if comfy_mm is None:
243
- return
244
- if not hasattr(comfy_mm, "free_memory") or not hasattr(comfy_mm, "get_torch_device"):
245
- return
246
  try:
247
- comfy_mm.free_memory(int(target_free_bytes), comfy_mm.get_torch_device())
 
 
248
  except Exception:
249
- pass
250
- _comfy_soft_empty_cache()
251
- _cuda_soft_cleanup()
252
 
253
 
254
- def _comfy_unload_smallest_model_once() -> bool:
255
  """
256
- Unload exactly ONE smallest model tracked by ComfyUI (best-effort).
257
- Returns True if we unloaded something.
 
258
  """
259
  if comfy_mm is None:
260
  return False
261
- if not hasattr(comfy_mm, "current_loaded_models") or not hasattr(comfy_mm, "get_torch_device"):
262
  return False
263
 
264
  try:
265
- dev = comfy_mm.get_torch_device()
266
  except Exception:
267
- dev = None
268
 
269
- loaded = []
270
  try:
271
- for lm in list(getattr(comfy_mm, "current_loaded_models", [])):
272
  try:
273
- if dev is not None and getattr(lm, "device", None) != dev:
274
- continue
275
- if hasattr(lm, "is_dead") and callable(lm.is_dead) and lm.is_dead():
276
  continue
277
 
278
- mem = 0
279
  mem_fn = getattr(lm, "model_loaded_memory", None)
280
  if callable(mem_fn):
281
  mem = int(mem_fn())
@@ -283,21 +244,20 @@ def _comfy_unload_smallest_model_once() -> bool:
283
  mem = int(getattr(lm, "loaded_memory", 0) or 0)
284
 
285
  if mem > 0:
286
- loaded.append((mem, lm))
287
  except Exception:
288
  continue
289
  except Exception:
290
  return False
291
 
292
- if not loaded:
293
  return False
294
 
295
- loaded.sort(key=lambda x: x[0]) # smallest first
296
- _mem, smallest = loaded[0]
297
 
298
- # Unload it
299
  try:
300
- unload_fn = getattr(smallest, "model_unload", None)
301
  if callable(unload_fn):
302
  try:
303
  unload_fn(unpatch_weights=True)
@@ -306,7 +266,7 @@ def _comfy_unload_smallest_model_once() -> bool:
306
  except Exception:
307
  pass
308
 
309
- # Ask ComfyUI to clean up bookkeeping (if available)
310
  try:
311
  cleanup = getattr(comfy_mm, "cleanup_models", None)
312
  if callable(cleanup):
@@ -319,94 +279,35 @@ def _comfy_unload_smallest_model_once() -> bool:
319
  return True
320
 
321
 
322
- def _evict_until_free(target_free_bytes: int, policy: str) -> None:
323
- """
324
- policy:
325
- - "smallest_first": unload smallest ComfyUI model repeatedly until free>=target or nothing left
326
- - "comfy_default": comfy_mm.free_memory(target, device)
327
- - "unload_all": unload_all_models()
328
- - "none": do nothing
329
- """
330
- if policy == "none":
331
  return
332
-
333
- with _GLOBAL_VRAM_EVICT_LOCK:
334
- if policy == "unload_all":
335
- _comfy_unload_all_models()
336
- return
337
-
338
- if policy == "comfy_default":
339
- _comfy_free_memory_to_target(target_free_bytes)
340
- return
341
-
342
- if policy == "smallest_first":
343
- # Loop until enough memory or no more models to unload
344
- for _ in range(256):
345
- free_b = _get_free_vram_bytes_best_effort()
346
- if free_b is None or free_b >= target_free_bytes:
347
- break
348
- if not _comfy_unload_smallest_model_once():
349
- break
350
- return
351
 
352
 
353
  # -----------------------------------------------------------------------------
354
- # Remover singleton cache (shared by all nodes)
355
  # -----------------------------------------------------------------------------
356
 
357
- # key = (mode, jit, ckpt)
358
- _RemKey = Tuple[str, bool, Optional[str]]
359
-
360
- _REMOVER_CACHE: Dict[_RemKey, Remover] = {}
361
- _REMOVER_RUN_LOCKS: Dict[_RemKey, threading.Lock] = {}
362
- _REMOVER_VRAM_BYTES: Dict[_RemKey, int] = {} # approximate model VRAM residency cost (bytes)
363
  _CACHE_LOCK = threading.Lock()
364
 
365
 
366
- def _construct_remover(mode: str, jit: bool, device: str, ckpt: Optional[str]) -> Remover:
367
- """
368
- Construct Remover with best-effort compatibility across transparent_background versions.
369
- """
370
- # Try current signature: Remover(mode=..., jit=..., device=..., ckpt=...)
371
- kwargs: Dict[str, Any] = {"jit": jit, "device": device}
372
- if mode:
373
- kwargs["mode"] = mode
374
- if ckpt:
375
- kwargs["ckpt"] = ckpt
376
-
377
- try:
378
- return Remover(**kwargs)
379
- except TypeError:
380
- pass
381
-
382
- # Try without "mode" (some variants)
383
- kwargs2: Dict[str, Any] = {"jit": jit, "device": device}
384
- if ckpt:
385
- kwargs2["ckpt"] = ckpt
386
- try:
387
- return Remover(**kwargs2)
388
- except TypeError:
389
- pass
390
-
391
- # Try legacy "fast=" API
392
- kwargs3: Dict[str, Any] = {"jit": jit, "device": device, "fast": (mode == "fast")}
393
- if ckpt:
394
- kwargs3["ckpt"] = ckpt
395
- return Remover(**kwargs3)
396
-
397
-
398
- def _get_remover(mode: str = "base", jit: bool = False, ckpt: Optional[str] = None) -> tuple[Remover, threading.Lock, _RemKey]:
399
- """
400
- Cached Remover per (mode, jit, ckpt). Constructed on CPU by default to avoid VRAM OOM.
401
- """
402
- key: _RemKey = (mode, jit, ckpt)
403
-
404
  with _CACHE_LOCK:
405
  inst = _REMOVER_CACHE.get(key)
406
  if inst is None:
407
  _ensure_ckpt_base()
408
  try:
409
- inst = _construct_remover(mode=mode, jit=jit, device="cpu", ckpt=ckpt)
410
  except BaseException as e:
411
  if _is_oom_error(e):
412
  _cuda_soft_cleanup()
@@ -418,205 +319,204 @@ def _get_remover(mode: str = "base", jit: bool = False, ckpt: Optional[str] = No
418
  run_lock = threading.Lock()
419
  _REMOVER_RUN_LOCKS[key] = run_lock
420
 
421
- return inst, run_lock, key
 
422
 
 
 
 
423
 
424
- def _get_target_device_str(device_choice: str) -> str:
 
 
 
 
 
 
 
425
  """
426
- device_choice:
427
- - "auto": ComfyUI device if available, else "cuda:0" if cuda else "cpu"
428
- - "cuda": "cuda:0"
429
- - "cpu": "cpu"
430
  """
431
- if device_choice == "cpu":
432
- return "cpu"
433
- if device_choice == "cuda":
434
- return "cuda:0"
435
 
436
- # auto:
437
- if comfy_mm is not None and hasattr(comfy_mm, "get_torch_device"):
 
438
  try:
439
- return str(comfy_mm.get_torch_device())
440
  except Exception:
441
  pass
 
 
 
442
 
443
- if torch.cuda.is_available():
444
- return "cuda:0"
445
- return "cpu"
446
-
447
-
448
- def _move_remover_to_device(
449
- remover: Remover,
450
- *,
451
- key: _RemKey,
452
- device_str: str,
453
- min_free_vram_mb: int,
454
- extra_vram_mb: int,
455
- unload_policy: str,
456
- measure_model_vram: bool,
457
- ) -> tuple[bool, Optional[int]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  """
459
- Move remover.model to device_str (best effort).
460
- For CUDA: evict models first if free VRAM < required.
461
- Returns (ok, measured_model_vram_bytes_or_None)
462
  """
463
- # Already on device?
464
- if str(getattr(remover, "device", "")) == device_str:
465
- return True, _REMOVER_VRAM_BYTES.get(key)
466
 
467
- # CPU path: always try
468
- if device_str.startswith("cpu") or device_str.startswith("mps") or device_str.startswith("xpu") or device_str.startswith("npu") or device_str.startswith("mlu"):
469
- try:
470
- remover.model = remover.model.to(device_str)
471
- remover.device = device_str
472
- _cuda_soft_cleanup()
473
- return True, _REMOVER_VRAM_BYTES.get(key)
474
- except BaseException as e:
475
- if _is_oom_error(e):
476
- _cuda_soft_cleanup()
477
- return False, _REMOVER_VRAM_BYTES.get(key)
478
 
479
- # CUDA path
480
- required_bytes = _bytes_from_mb(min_free_vram_mb)
 
481
 
482
- # If we have a measured model size, require that + extra_vram_mb instead
483
- measured = _REMOVER_VRAM_BYTES.get(key)
484
- if measured is not None and measured > 0:
485
- required_bytes = max(required_bytes, int(measured) + _bytes_from_mb(extra_vram_mb))
 
486
 
487
- free_b = _get_free_vram_bytes_best_effort()
488
- if free_b is not None and free_b < required_bytes:
489
- _evict_until_free(required_bytes, unload_policy)
490
 
491
- before = _get_free_vram_bytes_best_effort() if measure_model_vram else None
492
 
493
- try:
494
- remover.model = remover.model.to(device_str)
495
- remover.device = device_str
 
 
 
 
496
 
497
- _comfy_soft_empty_cache()
498
- _cuda_soft_cleanup()
499
 
500
- if measure_model_vram:
501
- after = _get_free_vram_bytes_best_effort()
502
- if before is not None and after is not None:
503
- used = max(0, int(before) - int(after))
504
- if used > 0:
505
- _REMOVER_VRAM_BYTES[key] = used
506
 
507
- return True, _REMOVER_VRAM_BYTES.get(key)
508
 
509
- except BaseException as e:
510
- if _is_oom_error(e):
 
 
511
  _cuda_soft_cleanup()
512
- return False, _REMOVER_VRAM_BYTES.get(key)
513
 
 
 
 
514
 
515
- def _remover_process_basic(remover: Remover, run_lock: threading.Lock, pil_img: Image.Image, out_type: str) -> Image.Image:
516
- """
517
- Basic safe call. May still OOM if on GPU and VRAM is tight.
518
- """
519
- with run_lock:
520
- with torch.inference_mode():
521
- return remover.process(pil_img, type=out_type)
522
-
523
-
524
- def _remover_process_no_crash(
525
- remover: Remover,
526
- run_lock: threading.Lock,
527
- *,
528
- key: _RemKey,
529
- pil_img_rgb: Image.Image,
530
- out_type: str,
531
- device_choice: str,
532
- min_free_vram_mb: int,
533
- extra_vram_mb: int,
534
- unload_policy: str,
535
- allow_cpu_fallback: bool,
536
- use_fp16_autocast: bool,
537
- ) -> Optional[Image.Image]:
538
  """
539
- Try inference:
540
- 1) run once
541
- 2) on OOM: evict VRAM (policy) and retry once
542
- 3) on OOM: optional CPU fallback and retry once
543
- Returns PIL on success, None if still OOM.
544
  """
545
- def _amp_ctx():
546
- dev = str(getattr(remover, "device", "") or "")
547
- if use_fp16_autocast and torch.cuda.is_available() and dev.startswith("cuda"):
548
- try:
549
- return torch.autocast("cuda", dtype=torch.float16)
550
- except Exception:
551
- return nullcontext()
552
- return nullcontext()
553
 
554
- # First try
 
 
 
555
  try:
556
- with run_lock:
557
  with torch.inference_mode():
558
- with _amp_ctx():
559
- return remover.process(pil_img_rgb, type=out_type)
 
 
 
560
  except BaseException as e:
561
  if not _is_oom_error(e):
562
  raise
563
 
564
- # OOM: cleanup + evict + retry
 
565
  _cuda_soft_cleanup()
566
- required_bytes = _bytes_from_mb(min_free_vram_mb)
567
- measured = _REMOVER_VRAM_BYTES.get(key)
568
- if measured is not None and measured > 0:
569
- required_bytes = max(required_bytes, int(measured) + _bytes_from_mb(extra_vram_mb))
570
-
571
- _evict_until_free(required_bytes, unload_policy)
572
 
573
  try:
574
- with run_lock:
575
  with torch.inference_mode():
576
- with _amp_ctx():
577
- return remover.process(pil_img_rgb, type=out_type)
 
 
578
  except BaseException as e:
579
  if not _is_oom_error(e):
580
  raise
581
 
582
- if not allow_cpu_fallback:
583
- return None
584
 
585
- # CPU fallback
586
  try:
587
- ok, _ = _move_remover_to_device(
588
- remover,
589
- key=key,
590
- device_str="cpu",
591
- min_free_vram_mb=min_free_vram_mb,
592
- extra_vram_mb=extra_vram_mb,
593
- unload_policy="none",
594
- measure_model_vram=False,
595
- )
596
- if not ok:
597
- return None
598
-
599
- _cuda_soft_cleanup()
600
-
601
- with run_lock:
602
  with torch.inference_mode():
603
- return remover.process(pil_img_rgb, type=out_type)
 
 
 
604
  except BaseException as e:
605
  if not _is_oom_error(e):
606
  raise
607
- return None
 
 
 
 
 
 
 
 
 
 
 
 
608
 
609
 
610
  # -----------------------------------------------------------------------------
611
- # Existing Nodes: InspyrenetRembg2 / InspyrenetRembg3 (kept)
612
  # -----------------------------------------------------------------------------
613
 
614
  class InspyrenetRembg2:
615
- """
616
- Kept behavior/output.
617
- Uses cached Remover (constructed on CPU by default in this file).
618
- If you want it on GPU: call Load_Inspyrenet_Global first with matching (mode/jit/ckpt).
619
- """
620
  def __init__(self):
621
  pass
622
 
@@ -625,7 +525,7 @@ class InspyrenetRembg2:
625
  return {
626
  "required": {
627
  "image": ("IMAGE",),
628
- "torchscript_jit": (["default", "on"],),
629
  },
630
  }
631
 
@@ -635,18 +535,21 @@ class InspyrenetRembg2:
635
 
636
  def remove_background(self, image, torchscript_jit):
637
  jit = (torchscript_jit != "default")
638
- remover, run_lock, _key = _get_remover(mode="base", jit=jit, ckpt=None)
639
 
640
  img_list = []
641
  for img in tqdm(image, "Inspyrenet Rembg2"):
642
  pil_in = tensor2pil(img)
643
  try:
644
- mid = _remover_process_basic(remover, run_lock, pil_in, out_type="rgba")
 
 
645
  except BaseException as e:
646
  if _is_oom_error(e):
647
  _cuda_soft_cleanup()
648
  raise RuntimeError("InspyrenetRembg2: CUDA out of memory.") from e
649
  raise
 
650
  out = pil2tensor(mid)
651
  img_list.append(out)
652
  del pil_in, mid, out
@@ -657,11 +560,6 @@ class InspyrenetRembg2:
657
 
658
 
659
  class InspyrenetRembg3:
660
- """
661
- Kept behavior/output.
662
- Uses cached Remover (constructed on CPU by default in this file).
663
- If you want it on GPU: call Load_Inspyrenet_Global first with matching (mode/jit/ckpt).
664
- """
665
  def __init__(self):
666
  pass
667
 
@@ -678,7 +576,7 @@ class InspyrenetRembg3:
678
  CATEGORY = "image"
679
 
680
  def remove_background(self, image):
681
- remover, run_lock, _key = _get_remover(mode="base", jit=False, ckpt=None)
682
 
683
  img_list = []
684
  for img in tqdm(image, "Inspyrenet Rembg3"):
@@ -686,7 +584,9 @@ class InspyrenetRembg3:
686
  pil_rgb = _rgba_to_rgb_on_white(pil_in)
687
 
688
  try:
689
- mid = _remover_process_basic(remover, run_lock, pil_rgb, out_type="rgba")
 
 
690
  except BaseException as e:
691
  if _is_oom_error(e):
692
  _cuda_soft_cleanup()
@@ -695,7 +595,6 @@ class InspyrenetRembg3:
695
 
696
  out = pil2tensor(mid)
697
  img_list.append(out)
698
-
699
  del pil_in, pil_rgb, mid, out
700
 
701
  img_stack = torch.cat(img_list, dim=0)
@@ -703,77 +602,36 @@ class InspyrenetRembg3:
703
 
704
 
705
  # -----------------------------------------------------------------------------
706
- # New Nodes: Load / Remove / Run (Global)
707
  # -----------------------------------------------------------------------------
708
 
709
  class Load_Inspyrenet_Global:
710
  """
711
- Loads the global singleton Remover instance (per mode/jit/ckpt) and moves it to target device.
712
- Measures approximate VRAM delta (best-effort) for the model residency.
 
 
713
  """
714
  def __init__(self):
715
  pass
716
 
717
  @classmethod
718
  def INPUT_TYPES(s):
719
- return {
720
- "required": {
721
- "mode": (["base", "fast", "base-nightly"], {"default": "base"}),
722
- "torchscript_jit": (["off", "on"], {"default": "off"}),
723
- "device": (["auto", "cuda", "cpu"], {"default": "auto"}),
724
 
725
- # target VRAM policy
726
- "min_free_vram_mb": ("INT", {"default": 4096, "min": 0, "max": 65536, "step": 256}),
727
- "extra_vram_mb": ("INT", {"default": 1024, "min": 0, "max": 65536, "step": 256}),
728
- "unload_policy": (["smallest_first", "comfy_default", "unload_all", "none"], {"default": "smallest_first"}),
729
-
730
- "measure_model_vram": (["yes", "no"], {"default": "yes"}),
731
- },
732
- "optional": {
733
- "ckpt_override": ("STRING", {"default": ""}),
734
- },
735
- }
736
-
737
- RETURN_TYPES = ("BOOLEAN", "INT", "STRING")
738
  FUNCTION = "load"
739
  CATEGORY = "image"
740
 
741
- def load(
742
- self,
743
- mode: str,
744
- torchscript_jit: str,
745
- device: str,
746
- min_free_vram_mb: int,
747
- extra_vram_mb: int,
748
- unload_policy: str,
749
- measure_model_vram: str,
750
- ckpt_override: str = "",
751
- ):
752
- jit = (torchscript_jit == "on")
753
- ckpt = ckpt_override.strip() or None
754
-
755
- remover, _run_lock, key = _get_remover(mode=mode, jit=jit, ckpt=ckpt)
756
-
757
- device_str = _get_target_device_str(device)
758
- ok, measured_bytes = _move_remover_to_device(
759
- remover,
760
- key=key,
761
- device_str=device_str,
762
- min_free_vram_mb=int(min_free_vram_mb),
763
- extra_vram_mb=int(extra_vram_mb),
764
- unload_policy=unload_policy,
765
- measure_model_vram=(measure_model_vram == "yes"),
766
- )
767
-
768
- measured_mb = int((measured_bytes or 0) / (1024 * 1024))
769
- status = f"Load_Inspyrenet_Global: ok={ok}, mode={mode}, jit={jit}, device={getattr(remover,'device',None)}, measured_model_vram_mb={measured_mb}"
770
- return (bool(ok), measured_mb, status)
771
-
772
-
773
- class Remove_Inspyrenet_Gobal:
774
  """
775
- Offloads the global singleton Remover to CPU (keeps instance), or deletes it (forces re-create later).
776
- Optionally unloads all ComfyUI models too.
777
  """
778
  def __init__(self):
779
  pass
@@ -782,60 +640,44 @@ class Remove_Inspyrenet_Gobal:
782
  def INPUT_TYPES(s):
783
  return {
784
  "required": {
785
- "mode": (["base", "fast", "base-nightly"], {"default": "base"}),
786
- "torchscript_jit": (["off", "on"], {"default": "off"}),
787
- "action": (["offload_to_cpu", "delete_instance"], {"default": "offload_to_cpu"}),
788
- "also_unload_all_models": (["no", "yes"], {"default": "no"}),
789
- },
790
- "optional": {
791
- "ckpt_override": ("STRING", {"default": ""}),
792
- },
793
  }
794
 
795
- RETURN_TYPES = ("BOOLEAN", "STRING")
796
  FUNCTION = "remove"
797
  CATEGORY = "image"
798
 
799
- def remove(self, mode, torchscript_jit, action, also_unload_all_models, ckpt_override=""):
800
- jit = (torchscript_jit == "on")
801
- ckpt = ckpt_override.strip() or None
802
-
803
- remover, _run_lock, key = _get_remover(mode=mode, jit=jit, ckpt=ckpt)
804
-
805
- # Offload remover itself to CPU
806
- try:
807
- remover.model = remover.model.to("cpu")
808
- remover.device = "cpu"
809
- except Exception:
810
- pass
811
 
812
- # Optionally delete instance from cache
813
- if action == "delete_instance":
814
- with _CACHE_LOCK:
815
- try:
816
- _REMOVER_CACHE.pop(key, None)
817
- _REMOVER_RUN_LOCKS.pop(key, None)
818
- _REMOVER_VRAM_BYTES.pop(key, None)
819
- except Exception:
820
- pass
821
-
822
- if also_unload_all_models == "yes":
823
- _comfy_unload_all_models()
 
 
 
824
 
825
- _comfy_soft_empty_cache()
826
  _cuda_soft_cleanup()
827
-
828
- status = f"Remove_Inspyrenet_Gobal: action={action}, offloaded_device={getattr(remover,'device',None)}, also_unload_all_models={also_unload_all_models}"
829
- return (True, status)
830
 
831
 
832
  class Run_InspyrenetRembg_Global:
833
  """
834
- Runs global Remover with OOM avoidance:
835
- - tries on requested device (auto/cuda/cpu)
836
- - on OOM: evicts models (policy) and retries
837
- - optional CPU fallback
838
- - NEVER crashes on OOM: last resort returns input image (pass-through RGBA)
839
  """
840
  def __init__(self):
841
  pass
@@ -845,107 +687,42 @@ class Run_InspyrenetRembg_Global:
845
  return {
846
  "required": {
847
  "image": ("IMAGE",),
848
- },
849
- "optional": {
850
- "mode": (["base", "fast", "base-nightly"], {"default": "base"}),
851
- "torchscript_jit": (["off", "on"], {"default": "off"}),
852
- "device": (["auto", "cuda", "cpu"], {"default": "auto"}),
853
-
854
- "min_free_vram_mb": ("INT", {"default": 4096, "min": 0, "max": 65536, "step": 256}),
855
- "extra_vram_mb": ("INT", {"default": 1024, "min": 0, "max": 65536, "step": 256}),
856
- "unload_policy": (["smallest_first", "comfy_default", "unload_all", "none"], {"default": "smallest_first"}),
857
-
858
- "allow_cpu_fallback": (["yes", "no"], {"default": "yes"}),
859
- "use_fp16_autocast": (["yes", "no"], {"default": "yes"}),
860
-
861
- "ckpt_override": ("STRING", {"default": ""}),
862
- },
863
  }
864
 
865
  RETURN_TYPES = ("IMAGE",)
866
  FUNCTION = "remove_background"
867
  CATEGORY = "image"
868
 
869
- def remove_background(
870
- self,
871
- image,
872
- mode="base",
873
- torchscript_jit="off",
874
- device="auto",
875
- min_free_vram_mb=4096,
876
- extra_vram_mb=1024,
877
- unload_policy="smallest_first",
878
- allow_cpu_fallback="yes",
879
- use_fp16_autocast="yes",
880
- ckpt_override="",
881
- ):
882
- jit = (torchscript_jit == "on")
883
- ckpt = ckpt_override.strip() or None
884
-
885
- remover, run_lock, key = _get_remover(mode=mode, jit=jit, ckpt=ckpt)
886
-
887
- # Ensure desired device (best-effort)
888
- device_str = _get_target_device_str(device)
889
- _move_remover_to_device(
890
- remover,
891
- key=key,
892
- device_str=device_str,
893
- min_free_vram_mb=int(min_free_vram_mb),
894
- extra_vram_mb=int(extra_vram_mb),
895
- unload_policy=unload_policy,
896
- measure_model_vram=False,
897
- )
898
-
899
- allow_cpu = (allow_cpu_fallback == "yes")
900
- fp16_amp = (use_fp16_autocast == "yes")
901
 
902
  img_list = []
903
  for img in tqdm(image, "Run InspyrenetRembg Global"):
904
  pil_in = tensor2pil(img)
905
 
906
- # Always keep a pass-through fallback (NOT white)
907
- fallback = _force_rgba_passthrough(pil_in)
908
 
909
- # Model input: RGB (your prior behavior uses white composite if RGBA)
910
  pil_rgb = _rgba_to_rgb_on_white(pil_in)
911
 
912
- out_pil = _remover_process_no_crash(
913
- remover,
914
- run_lock,
915
- key=key,
916
- pil_img_rgb=pil_rgb,
917
- out_type="rgba",
918
- device_choice=device,
919
- min_free_vram_mb=int(min_free_vram_mb),
920
- extra_vram_mb=int(extra_vram_mb),
921
- unload_policy=unload_policy,
922
- allow_cpu_fallback=allow_cpu,
923
- use_fp16_autocast=fp16_amp,
924
- )
925
-
926
- if out_pil is None:
927
- # Absolute last resort: return input pixels (pass-through), do not crash.
928
- out_pil = fallback
929
-
930
  out = pil2tensor(out_pil)
931
  img_list.append(out)
932
 
933
- del pil_in, pil_rgb, fallback, out_pil, out
934
 
935
  img_stack = torch.cat(img_list, dim=0)
936
  return (img_stack,)
937
 
938
 
939
- # -----------------------------------------------------------------------------
940
- # Node mappings
941
- # -----------------------------------------------------------------------------
942
-
943
  NODE_CLASS_MAPPINGS = {
944
  "InspyrenetRembg2": InspyrenetRembg2,
945
  "InspyrenetRembg3": InspyrenetRembg3,
946
 
947
  "Load_Inspyrenet_Global": Load_Inspyrenet_Global,
948
- "Remove_Inspyrenet_Gobal": Remove_Inspyrenet_Gobal,
949
  "Run_InspyrenetRembg_Global": Run_InspyrenetRembg_Global,
950
  }
951
 
@@ -954,6 +731,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
954
  "InspyrenetRembg3": "Inspyrenet Rembg3",
955
 
956
  "Load_Inspyrenet_Global": "Load Inspyrenet Global",
957
- "Remove_Inspyrenet_Gobal": "Remove Inspyrenet Gobal",
958
  "Run_InspyrenetRembg_Global": "Run InspyrenetRembg Global",
959
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  from PIL import Image
 
5
  import urllib.request
6
  import gc
7
  import threading
8
+ from typing import Dict, Tuple, Optional
 
9
 
10
  import torch
11
  import numpy as np
12
  from transparent_background import Remover
13
  from tqdm import tqdm
14
 
15
+
16
+ # Optional: ComfyUI memory manager (present inside ComfyUI)
17
  try:
18
  import comfy.model_management as comfy_mm
19
  except Exception:
 
25
 
26
 
27
  def _ensure_ckpt_base():
 
 
 
 
 
28
  try:
29
  if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
30
  return
 
61
  f.write(chunk)
62
 
63
  os.replace(tmp_path, CKPT_PATH)
 
64
  finally:
65
  if os.path.isfile(tmp_path):
66
  try:
 
69
  pass
70
 
71
 
72
+ # Tensor to PIL
 
 
 
73
  def tensor2pil(image: torch.Tensor) -> Image.Image:
74
  arr = image.detach().cpu().numpy()
75
  if arr.ndim == 4 and arr.shape[0] == 1:
 
78
  return Image.fromarray(arr)
79
 
80
 
81
# Convert PIL to Tensor
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a (1, H, W, C) float32 tensor scaled to [0, 1]."""
    pixels = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(pixels).unsqueeze(0)
84
 
85
 
86
  def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
 
 
 
 
 
 
 
87
  if pil_img.mode == "RGBA":
88
  bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
89
  composited = Image.alpha_composite(bg, pil_img)
 
95
  return pil_img
96
 
97
 
98
def _force_rgba_opaque(pil_img: Image.Image) -> Image.Image:
    """
    Return *pil_img* as RGBA with the alpha channel forced to fully opaque
    (255 everywhere), so a failed cutout can never produce an invisible frame.
    """
    converted = pil_img.convert("RGBA")
    channels = list(converted.split())
    # Replace whatever alpha was there with a solid, fully-opaque band.
    channels[3] = Image.new("L", converted.size, 255)
    return Image.merge("RGBA", tuple(channels))
106
 
107
 
108
+ def _alpha_is_all_zero(pil_img: Image.Image) -> bool:
109
+ """
110
+ True if RGBA image alpha channel is entirely 0.
111
+ """
112
+ if pil_img.mode != "RGBA":
113
+ return False
114
+ try:
115
+ extrema = pil_img.getextrema() # ((min,max),(min,max),(min,max),(min,max))
116
+ return extrema[3][1] == 0
117
+ except Exception:
118
+ return False
119
+
120
 
121
  def _is_oom_error(e: BaseException) -> bool:
122
  oom_cuda_cls = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
 
171
  pass
172
 
173
 
174
def _get_comfy_torch_device() -> torch.device:
    """
    Return the torch device this extension should use.

    Preference order: ComfyUI's own get_torch_device(), then cuda:0 when CUDA
    is available, then CPU. Any failure in the ComfyUI path falls through.
    """
    getter = getattr(comfy_mm, "get_torch_device", None) if comfy_mm is not None else None
    if getter is not None:
        try:
            chosen = getter()
            if isinstance(chosen, torch.device):
                return chosen
            # Some ComfyUI versions return a string-like value.
            return torch.device(str(chosen))
        except Exception:
            pass

    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
 
 
 
 
190
 
191
 
192
+ def _set_current_cuda_device(dev: torch.device) -> None:
193
+ """
194
+ Make sure mem_get_info() measurements are on the same device ComfyUI uses.
195
+ """
196
+ if dev.type == "cuda":
 
 
 
 
 
 
197
  try:
198
+ if dev.index is not None:
199
+ torch.cuda.set_device(dev.index)
200
  except Exception:
201
  pass
 
 
202
 
203
 
204
+ def _cuda_free_bytes_on(dev: torch.device) -> Optional[int]:
205
+ if dev.type != "cuda" or not torch.cuda.is_available():
206
+ return None
 
 
207
  try:
208
+ _set_current_cuda_device(dev)
209
+ free_b, _total_b = torch.cuda.mem_get_info()
210
+ return int(free_b)
211
  except Exception:
212
+ return None
 
 
213
 
214
 
215
+ def _comfy_unload_one_smallest_model() -> bool:
216
  """
217
+ Best-effort "smallest-first" eviction of one ComfyUI-tracked loaded model.
218
+
219
+ If ComfyUI internals differ, this may do nothing (and we fall back to unload_all_models()).
220
  """
221
  if comfy_mm is None:
222
  return False
223
+ if not hasattr(comfy_mm, "current_loaded_models"):
224
  return False
225
 
226
  try:
227
+ cur_dev = _get_comfy_torch_device()
228
  except Exception:
229
+ cur_dev = None
230
 
231
+ models = []
232
  try:
233
+ for lm in list(comfy_mm.current_loaded_models):
234
  try:
235
+ # Prefer same device
236
+ lm_dev = getattr(lm, "device", None)
237
+ if cur_dev is not None and lm_dev is not None and str(lm_dev) != str(cur_dev):
238
  continue
239
 
 
240
  mem_fn = getattr(lm, "model_loaded_memory", None)
241
  if callable(mem_fn):
242
  mem = int(mem_fn())
 
244
  mem = int(getattr(lm, "loaded_memory", 0) or 0)
245
 
246
  if mem > 0:
247
+ models.append((mem, lm))
248
  except Exception:
249
  continue
250
  except Exception:
251
  return False
252
 
253
+ if not models:
254
  return False
255
 
256
+ models.sort(key=lambda x: x[0]) # smallest first
257
+ _mem, lm = models[0]
258
 
 
259
  try:
260
+ unload_fn = getattr(lm, "model_unload", None)
261
  if callable(unload_fn):
262
  try:
263
  unload_fn(unpatch_weights=True)
 
266
  except Exception:
267
  pass
268
 
269
+ # Cleanup hook if present
270
  try:
271
  cleanup = getattr(comfy_mm, "cleanup_models", None)
272
  if callable(cleanup):
 
279
  return True
280
 
281
 
282
def _comfy_unload_all_models() -> None:
    """
    Ask ComfyUI to unload every tracked model, then empty its cache and run
    our own CUDA cleanup. Silently does nothing outside ComfyUI.
    """
    if comfy_mm is None:
        return
    unloader = getattr(comfy_mm, "unload_all_models", None)
    if unloader is not None:
        try:
            unloader()
        except Exception:
            pass
    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
 
294
  # -----------------------------------------------------------------------------
295
+ # Existing singleton cache for Rembg2/Rembg3 (your original)
296
  # -----------------------------------------------------------------------------
297
 
298
+ _REMOVER_CACHE: Dict[Tuple[bool], Remover] = {}
299
+ _REMOVER_RUN_LOCKS: Dict[Tuple[bool], threading.Lock] = {}
 
 
 
 
300
  _CACHE_LOCK = threading.Lock()
301
 
302
 
303
+ def _get_remover(jit: bool = False) -> tuple[Remover, threading.Lock]:
304
+ key = (jit,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  with _CACHE_LOCK:
306
  inst = _REMOVER_CACHE.get(key)
307
  if inst is None:
308
  _ensure_ckpt_base()
309
  try:
310
+ inst = Remover(jit=jit) if jit else Remover()
311
  except BaseException as e:
312
  if _is_oom_error(e):
313
  _cuda_soft_cleanup()
 
319
  run_lock = threading.Lock()
320
  _REMOVER_RUN_LOCKS[key] = run_lock
321
 
322
+ return inst, run_lock
323
+
324
 
325
+ # -----------------------------------------------------------------------------
326
+ # GLOBAL remover (for Load/Remove/Run Global nodes)
327
+ # -----------------------------------------------------------------------------
328
 
329
# Process-wide state for the "Global" nodes: exactly one Remover per process.
_GLOBAL_LOCK = threading.Lock()        # guards creation/deletion of the singleton
_GLOBAL_RUN_LOCK = threading.Lock()    # serializes remover.process() inference calls
_GLOBAL_REMOVER: Optional[Remover] = None   # lazily created singleton (None until first use)
_GLOBAL_ON_DEVICE: str = "cpu"              # device string the remover currently lives on
_GLOBAL_VRAM_DELTA_BYTES: int = 0           # best-effort VRAM cost measured at load time
334
+
335
+
336
def _create_global_remover_cpu() -> Remover:
    """
    Build the global Remover (defaults, jit disabled — like InspyrenetRembg3),
    trying to keep the initial weights on the CPU so creation cannot OOM the GPU.
    """
    _ensure_ckpt_base()

    # Newer transparent_background versions accept a device argument; older
    # ones raise TypeError, which we use to detect the fallback path.
    remover = None
    try:
        remover = Remover(device="cpu")  # type: ignore[arg-type]
    except TypeError:
        remover = None
    if remover is not None:
        try:
            remover.device = "cpu"
        except Exception:
            pass
        return remover

    # Fallback for older versions: construct with defaults, then immediately
    # push the weights onto the CPU and release any CUDA cache they touched.
    remover = Remover()
    try:
        if hasattr(remover, "model"):
            remover.model = remover.model.to("cpu")
        remover.device = "cpu"
    except Exception:
        pass
    _cuda_soft_cleanup()
    return remover
364
+
365
+
366
def _get_global_remover() -> Remover:
    """Return the process-wide Remover, creating it (on CPU) on first use."""
    global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE
    with _GLOBAL_LOCK:
        if _GLOBAL_REMOVER is None:
            created = _create_global_remover_cpu()
            _GLOBAL_REMOVER = created
            _GLOBAL_ON_DEVICE = str(getattr(created, "device", "cpu"))
        return _GLOBAL_REMOVER
373
+
374
+
375
def _move_global_to_cpu() -> None:
    """Offload the global remover's weights to CPU and free CUDA caches."""
    global _GLOBAL_ON_DEVICE
    remover = _get_global_remover()
    try:
        if hasattr(remover, "model"):
            remover.model = remover.model.to("cpu")
        remover.device = "cpu"
        _GLOBAL_ON_DEVICE = "cpu"
    except Exception:
        # Best-effort: leave bookkeeping untouched if the move failed.
        pass
    _cuda_soft_cleanup()
386
+
387
+
388
def _load_global_to_comfy_cuda_no_crash(max_evictions: int = 32) -> bool:
    """
    Move the global remover's weights onto ComfyUI's chosen CUDA device.

    Never raises on CUDA OOM: each failed attempt evicts ComfyUI-tracked
    models (smallest first, then all of them) and retries, up to
    ``max_evictions`` extra attempts. Also records a best-effort VRAM delta
    (weights residency only, not peak inference memory).

    Args:
        max_evictions: number of eviction+retry rounds after the first attempt.

    Returns:
        True if the remover ended up on CUDA, False if it stays on CPU.
    """
    global _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES

    r = _get_global_remover()
    dev = _get_comfy_torch_device()

    # No usable CUDA device: keep (or put) the remover on CPU and report False.
    if dev.type != "cuda" or not torch.cuda.is_available():
        _move_global_to_cpu()
        return False

    # Already on CUDA? Just sync the bookkeeping and succeed.
    cur_dev = str(getattr(r, "device", "") or "")
    if cur_dev.startswith("cuda"):
        _GLOBAL_ON_DEVICE = cur_dev
        return True

    # Measure free VRAM on the SAME device ComfyUI uses (set it current first).
    _set_current_cuda_device(dev)

    free_before = _cuda_free_bytes_on(dev)

    for _ in range(max_evictions + 1):
        try:
            # Move model to the SAME device ComfyUI uses
            if hasattr(r, "model"):
                r.model = r.model.to(dev)
            r.device = str(dev)
            _GLOBAL_ON_DEVICE = str(dev)

            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()

            # Best-effort VRAM cost: drop in free memory across the move.
            free_after = _cuda_free_bytes_on(dev)
            if free_before is not None and free_after is not None:
                delta = max(0, int(free_before) - int(free_after))
                if delta > 0:
                    _GLOBAL_VRAM_DELTA_BYTES = delta

            return True

        except BaseException as e:
            # Only OOM triggers the eviction/retry path; anything else is a bug.
            if not _is_oom_error(e):
                raise
            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()

            # Evict ONE smallest model; if that fails, unload all.
            if not _comfy_unload_one_smallest_model():
                _comfy_unload_all_models()

    # Exhausted retries: fall back to CPU rather than crashing.
    _move_global_to_cpu()
    return False
445
+
446
+
447
def _run_global_rgba_no_crash(pil_rgb: Image.Image, fallback_rgba: Image.Image) -> Image.Image:
    """
    Run the global remover (rgba output), matching InspyrenetRembg3 behavior,
    without ever crashing on CUDA OOM.

    Recovery ladder on OOM: retry after evicting the smallest ComfyUI model,
    then after unloading all models, then once more on CPU. Non-OOM errors on
    the GPU attempts propagate; the final CPU attempt swallows everything and
    returns *fallback_rgba*. A fully transparent result is also treated as a
    failure (prevents "invisible" output).

    Args:
        pil_rgb: model input (RGB; RGBA already composited onto white upstream).
        fallback_rgba: opaque passthrough returned on total failure.

    Returns:
        RGBA PIL image — the cutout, or *fallback_rgba* as last resort.
    """
    remover = _get_global_remover()

    # Try to keep it on CUDA (Comfy device) if possible; do not crash if not.
    _load_global_to_comfy_cuda_no_crash()

    def _attempt() -> Image.Image:
        # One serialized inference pass; transparent-everywhere counts as failure.
        with _GLOBAL_RUN_LOCK:
            with torch.inference_mode():
                out = remover.process(pil_rgb, type="rgba")
        if _alpha_is_all_zero(out):
            return fallback_rgba
        return out

    def _evict_one() -> None:
        # Free caches, then evict the single smallest ComfyUI-tracked model.
        _comfy_soft_empty_cache()
        _cuda_soft_cleanup()
        _comfy_unload_one_smallest_model()

    # GPU attempts: first as-is, then after each recovery step. Non-OOM raises.
    for recover in (None, _evict_one, _comfy_unload_all_models):
        if recover is not None:
            recover()
        try:
            return _attempt()
        except BaseException as e:
            if not _is_oom_error(e):
                raise

    # Final: CPU fallback. Any failure here yields the passthrough image.
    _move_global_to_cpu()
    try:
        return _attempt()
    except BaseException:
        return fallback_rgba
513
 
514
 
515
  # -----------------------------------------------------------------------------
516
+ # Nodes
517
  # -----------------------------------------------------------------------------
518
 
519
  class InspyrenetRembg2:
 
 
 
 
 
520
  def __init__(self):
521
  pass
522
 
 
525
  return {
526
  "required": {
527
  "image": ("IMAGE",),
528
+ "torchscript_jit": (["default", "on"],)
529
  },
530
  }
531
 
 
535
 
536
  def remove_background(self, image, torchscript_jit):
537
  jit = (torchscript_jit != "default")
538
+ remover, run_lock = _get_remover(jit=jit)
539
 
540
  img_list = []
541
  for img in tqdm(image, "Inspyrenet Rembg2"):
542
  pil_in = tensor2pil(img)
543
  try:
544
+ with run_lock:
545
+ with torch.inference_mode():
546
+ mid = remover.process(pil_in, type="rgba")
547
  except BaseException as e:
548
  if _is_oom_error(e):
549
  _cuda_soft_cleanup()
550
  raise RuntimeError("InspyrenetRembg2: CUDA out of memory.") from e
551
  raise
552
+
553
  out = pil2tensor(mid)
554
  img_list.append(out)
555
  del pil_in, mid, out
 
560
 
561
 
562
  class InspyrenetRembg3:
 
 
 
 
 
563
  def __init__(self):
564
  pass
565
 
 
576
  CATEGORY = "image"
577
 
578
  def remove_background(self, image):
579
+ remover, run_lock = _get_remover(jit=False)
580
 
581
  img_list = []
582
  for img in tqdm(image, "Inspyrenet Rembg3"):
 
584
  pil_rgb = _rgba_to_rgb_on_white(pil_in)
585
 
586
  try:
587
+ with run_lock:
588
+ with torch.inference_mode():
589
+ mid = remover.process(pil_rgb, type="rgba")
590
  except BaseException as e:
591
  if _is_oom_error(e):
592
  _cuda_soft_cleanup()
 
595
 
596
  out = pil2tensor(mid)
597
  img_list.append(out)
 
598
  del pil_in, pil_rgb, mid, out
599
 
600
  img_stack = torch.cat(img_list, dim=0)
 
602
 
603
 
604
  # -----------------------------------------------------------------------------
605
+ # NEW: Global nodes (simple, no user settings on Load/Run)
606
  # -----------------------------------------------------------------------------
607
 
608
class Load_Inspyrenet_Global:
    """
    Parameterless node: ensures the global remover exists and tries to move it
    onto ComfyUI's CUDA device.

    Outputs:
        loaded_ok (BOOLEAN): True when the remover ended up on CUDA.
        vram_delta_bytes (INT): best-effort VRAM cost of the resident weights
            (weights residency only, not peak inference memory).
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}}

    RETURN_TYPES = ("BOOLEAN", "INT")
    FUNCTION = "load"
    CATEGORY = "image"

    def load(self):
        _get_global_remover()
        on_cuda = _load_global_to_comfy_cuda_no_crash()
        return (bool(on_cuda), int(_GLOBAL_VRAM_DELTA_BYTES))
630
+
631
+
632
class Remove_Inspyrenet_Global:
    """
    Node that either offloads the global remover to CPU or deletes the
    singleton entirely (the next use recreates it from scratch).
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "action": (["offload_to_cpu", "delete_instance"],),
            }
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "remove"
    CATEGORY = "image"

    def remove(self, action):
        global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES

        if action == "offload_to_cpu":
            _move_global_to_cpu()
            return (True,)

        # action == "delete_instance": drop the singleton and reset bookkeeping.
        with _GLOBAL_LOCK:
            try:
                instance = _GLOBAL_REMOVER
                if instance is not None:
                    # Push weights to CPU first so dropping the reference
                    # actually frees VRAM.
                    try:
                        if hasattr(instance, "model"):
                            instance.model = instance.model.to("cpu")
                        instance.device = "cpu"
                    except Exception:
                        pass
                _GLOBAL_REMOVER = None
                _GLOBAL_ON_DEVICE = "cpu"
                _GLOBAL_VRAM_DELTA_BYTES = 0
            except Exception:
                pass

        _cuda_soft_cleanup()
        return (True,)
 
 
675
 
676
 
677
class Run_InspyrenetRembg_Global:
    """
    Parameterless background-removal node using the global remover.

    Same behavior as InspyrenetRembg3, but it never crashes on CUDA OOM:
    the worker falls back through eviction/CPU, and on total failure each
    frame is returned as an opaque RGBA passthrough (never an invisible image).
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image):
        """
        Process a batch tensor frame-by-frame and return (IMAGE,) of RGBA frames.

        Robustness fix: an empty batch is passed through unchanged instead of
        crashing in torch.cat([]).
        """
        _get_global_remover()

        img_list = []
        for img in tqdm(image, "Run InspyrenetRembg Global"):
            pil_in = tensor2pil(img)

            # Visible fallback (never invisible)
            fallback = _force_rgba_opaque(pil_in)

            # Exactly like Rembg3's input path: composite RGBA onto white.
            pil_rgb = _rgba_to_rgb_on_white(pil_in)

            out_pil = _run_global_rgba_no_crash(pil_rgb, fallback)
            img_list.append(pil2tensor(out_pil))

            del pil_in, fallback, pil_rgb, out_pil

        if not img_list:
            # Nothing to process; return the input batch unchanged.
            return (image,)

        img_stack = torch.cat(img_list, dim=0)
        return (img_stack,)
718
 
719
 
 
 
 
 
720
# Registration table consumed by ComfyUI: node key -> implementing class.
# NOTE: key renamed from the old misspelled "Remove_Inspyrenet_Gobal";
# workflows saved with the old key must be updated to the new name.
NODE_CLASS_MAPPINGS = {
    "InspyrenetRembg2": InspyrenetRembg2,
    "InspyrenetRembg3": InspyrenetRembg3,

    "Load_Inspyrenet_Global": Load_Inspyrenet_Global,
    "Remove_Inspyrenet_Global": Remove_Inspyrenet_Global,
    "Run_InspyrenetRembg_Global": Run_InspyrenetRembg_Global,
}
728
 
 
731
  "InspyrenetRembg3": "Inspyrenet Rembg3",
732
 
733
  "Load_Inspyrenet_Global": "Load Inspyrenet Global",
734
+ "Remove_Inspyrenet_Global": "Remove Inspyrenet Global",
735
  "Run_InspyrenetRembg_Global": "Run InspyrenetRembg Global",
736
  }