| from __future__ import annotations
|
|
|
| from PIL import Image
|
| import os
|
| import urllib.request
|
| import gc
|
| import threading
|
| from typing import Dict, Tuple, Optional
|
|
|
| import torch
|
| import numpy as np
|
| from transparent_background import Remover
|
| from tqdm import tqdm
|
|
|
|
|
|
|
# ComfyUI's model-management module; left as None when this file is imported
# outside ComfyUI so every comfy_mm-dependent helper degrades to a no-op.
try:
    import comfy.model_management as comfy_mm
except Exception:
    comfy_mm = None
|
|
|
|
|
# Location where transparent-background looks for its base checkpoint.
CKPT_PATH = "/root/.transparent-background/ckpt_base.pth"
# Mirror URL used to fetch the checkpoint when it is absent or empty.
CKPT_URL = "https://huggingface.co/saliacoel/x/resolve/main/ckpt_base.pth"
|
|
|
|
|
def _ensure_ckpt_base():
    """Download the InSPyReNet base checkpoint to CKPT_PATH if missing or empty.

    The file is streamed to a temporary sibling path and atomically moved into
    place with os.replace, so a partial download never ends up at CKPT_PATH.
    A leftover temp file is removed on failure.
    """
    try:
        # Already present and non-empty: nothing to do.
        if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
            return
    except Exception:
        pass

    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    tmp_path = CKPT_PATH + ".tmp"

    def _copy_chunks(resp, f, pbar=None):
        # Stream the response in 1 MiB chunks; update progress when a bar is given.
        while True:
            chunk = resp.read(1024 * 1024)
            if not chunk:
                break
            f.write(chunk)
            if pbar is not None:
                pbar.update(len(chunk))

    try:
        with urllib.request.urlopen(CKPT_URL) as resp:
            total = resp.headers.get("Content-Length")
            total = int(total) if total is not None else None

            with open(tmp_path, "wb") as f:
                if total:
                    with tqdm(
                        total=total,
                        unit="B",
                        unit_scale=True,
                        desc="Downloading ckpt_base.pth",
                    ) as pbar:
                        _copy_chunks(resp, f, pbar)
                else:
                    # Unknown size: download without a progress bar.
                    _copy_chunks(resp, f)

        os.replace(tmp_path, CKPT_PATH)
    finally:
        # Remove a partial download; after a successful os.replace the temp
        # file no longer exists, so this is a no-op on success.
        if os.path.isfile(tmp_path):
            try:
                os.remove(tmp_path)
            except Exception:
                pass
|
|
|
|
|
|
|
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """Convert a ComfyUI image tensor (float values in [0, 1]) to a PIL image."""
    data = image.detach().cpu().numpy()
    # Drop a leading batch dimension of size 1 (ComfyUI batches images NHWC).
    if data.ndim == 4 and data.shape[0] == 1:
        data = data[0]
    scaled = np.clip(data * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(scaled)
|
|
|
|
|
|
|
| def pil2tensor(image: Image.Image) -> torch.Tensor:
|
| return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
|
|
|
|
| def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
|
| if pil_img.mode == "RGBA":
|
| bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
|
| composited = Image.alpha_composite(bg, pil_img)
|
| return composited.convert("RGB")
|
|
|
| if pil_img.mode != "RGB":
|
| return pil_img.convert("RGB")
|
|
|
| return pil_img
|
|
|
|
|
def _force_rgba_opaque(pil_img: Image.Image) -> Image.Image:
    """
    Opaque RGBA fallback (alpha=255), so you never get an "invisible" output.
    """
    rgba = pil_img.convert("RGBA")
    red, green, blue, _ = rgba.split()
    # Replace whatever alpha was there with a fully opaque channel.
    opaque = Image.new("L", rgba.size, 255)
    return Image.merge("RGBA", (red, green, blue, opaque))
|
|
|
|
|
| def _alpha_is_all_zero(pil_img: Image.Image) -> bool:
|
| """
|
| True if RGBA image alpha channel is entirely 0.
|
| """
|
| if pil_img.mode != "RGBA":
|
| return False
|
| try:
|
| extrema = pil_img.getextrema()
|
| return extrema[3][1] == 0
|
| except Exception:
|
| return False
|
|
|
|
|
| def _is_oom_error(e: BaseException) -> bool:
|
| oom_cuda_cls = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
|
| if oom_cuda_cls is not None and isinstance(e, oom_cuda_cls):
|
| return True
|
|
|
| oom_torch_cls = getattr(torch, "OutOfMemoryError", None)
|
| if oom_torch_cls is not None and isinstance(e, oom_torch_cls):
|
| return True
|
|
|
| msg = str(e).lower()
|
| if "out of memory" in msg:
|
| return True
|
| if "allocation on device" in msg:
|
| return True
|
| return ("cuda" in msg or "cublas" in msg or "hip" in msg) and ("memory" in msg)
|
|
|
|
|
| def _cuda_soft_cleanup() -> None:
|
| try:
|
| gc.collect()
|
| except Exception:
|
| pass
|
|
|
| if torch.cuda.is_available():
|
| try:
|
| torch.cuda.synchronize()
|
| except Exception:
|
| pass
|
| try:
|
| torch.cuda.empty_cache()
|
| except Exception:
|
| pass
|
| try:
|
| torch.cuda.ipc_collect()
|
| except Exception:
|
| pass
|
|
|
|
|
def _comfy_soft_empty_cache() -> None:
    """Ask ComfyUI to release its cached VRAM, tolerating older signatures.

    Newer ComfyUI accepts soft_empty_cache(force=True); older builds take no
    arguments. A missing module/attribute or any runtime error is ignored.
    """
    if comfy_mm is None:
        return
    fn = getattr(comfy_mm, "soft_empty_cache", None)
    if fn is None:
        return
    try:
        fn(force=True)
    except TypeError:
        # Older signature without the `force` keyword.
        try:
            fn()
        except Exception:
            pass
    except Exception:
        pass
|
|
|
|
|
def _get_comfy_torch_device() -> torch.device:
    """
    Always prefer ComfyUI's chosen device.

    Falls back to cuda:0 when CUDA is available, else CPU.
    """
    getter = getattr(comfy_mm, "get_torch_device", None) if comfy_mm is not None else None
    if getter is not None:
        try:
            dev = getter()
            # Normalize anything device-like into a real torch.device.
            return dev if isinstance(dev, torch.device) else torch.device(str(dev))
        except Exception:
            pass

    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
| def _set_current_cuda_device(dev: torch.device) -> None:
|
| """
|
| Make sure mem_get_info() measurements are on the same device ComfyUI uses.
|
| """
|
| if dev.type == "cuda":
|
| try:
|
| if dev.index is not None:
|
| torch.cuda.set_device(dev.index)
|
| except Exception:
|
| pass
|
|
|
|
|
| def _cuda_free_bytes_on(dev: torch.device) -> Optional[int]:
|
| if dev.type != "cuda" or not torch.cuda.is_available():
|
| return None
|
| try:
|
| _set_current_cuda_device(dev)
|
| free_b, _total_b = torch.cuda.mem_get_info()
|
| return int(free_b)
|
| except Exception:
|
| return None
|
|
|
|
|
def _comfy_unload_one_smallest_model() -> bool:
    """
    Best-effort "smallest-first" eviction of one ComfyUI-tracked loaded model.

    If ComfyUI internals differ, this may do nothing (and we fall back to unload_all_models()).

    Returns True when an eviction was attempted, False when there was nothing
    to evict (no ComfyUI, no tracked models, or internals not as expected).
    """
    # No ComfyUI at all, or an incompatible version: nothing we can do.
    if comfy_mm is None:
        return False
    if not hasattr(comfy_mm, "current_loaded_models"):
        return False

    try:
        cur_dev = _get_comfy_torch_device()
    except Exception:
        cur_dev = None

    # Collect (memory, model) pairs for models resident on the current device.
    models = []
    try:
        for lm in list(comfy_mm.current_loaded_models):
            try:

                # Skip models living on a different device than the one we want to free.
                lm_dev = getattr(lm, "device", None)
                if cur_dev is not None and lm_dev is not None and str(lm_dev) != str(cur_dev):
                    continue

                # Prefer the callable accessor; fall back to the plain attribute.
                mem_fn = getattr(lm, "model_loaded_memory", None)
                if callable(mem_fn):
                    mem = int(mem_fn())
                else:
                    mem = int(getattr(lm, "loaded_memory", 0) or 0)

                # Only consider models that actually occupy memory.
                if mem > 0:
                    models.append((mem, lm))
            except Exception:
                continue
    except Exception:
        return False

    if not models:
        return False

    # Evict the smallest model first to minimize disruption to the workflow.
    models.sort(key=lambda x: x[0])
    _mem, lm = models[0]

    try:
        unload_fn = getattr(lm, "model_unload", None)
        if callable(unload_fn):
            try:
                unload_fn(unpatch_weights=True)
            except TypeError:
                # Older ComfyUI: model_unload() takes no keyword arguments.
                unload_fn()
    except Exception:
        pass

    # Let ComfyUI drop references to unloaded models, then flush caches.
    try:
        cleanup = getattr(comfy_mm, "cleanup_models", None)
        if callable(cleanup):
            cleanup()
    except Exception:
        pass

    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
    return True
|
|
|
|
|
def _comfy_unload_all_models() -> None:
    """Evict every ComfyUI-tracked model, then flush caches (best effort)."""
    if comfy_mm is None:
        return
    unload = getattr(comfy_mm, "unload_all_models", None)
    if unload is not None:
        try:
            unload()
        except Exception:
            pass
    _comfy_soft_empty_cache()
    _cuda_soft_cleanup()
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cache of Remover instances keyed by (jit,), plus one run lock per instance so
# concurrent node executions never share a single forward pass.
_REMOVER_CACHE: Dict[Tuple[bool], Remover] = {}
_REMOVER_RUN_LOCKS: Dict[Tuple[bool], threading.Lock] = {}
# Guards creation/lookup in the two dicts above.
_CACHE_LOCK = threading.Lock()
|
|
|
|
|
def _get_remover(jit: bool = False) -> tuple[Remover, threading.Lock]:
    """Return a cached Remover (and its run lock) for the given jit setting.

    The instance is created lazily under _CACHE_LOCK, downloading the
    checkpoint first when needed. On OOM during construction the CUDA caches
    are flushed before the exception is re-raised.
    """
    key = (jit,)
    with _CACHE_LOCK:
        inst = _REMOVER_CACHE.get(key)
        if inst is None:
            _ensure_ckpt_base()
            try:
                # jit=True requests the TorchScript model; default args otherwise.
                inst = Remover(jit=jit) if jit else Remover()
            except BaseException as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                raise
            _REMOVER_CACHE[key] = inst

        # One lock per cached instance: serializes concurrent inference on it.
        run_lock = _REMOVER_RUN_LOCKS.get(key)
        if run_lock is None:
            run_lock = threading.Lock()
            _REMOVER_RUN_LOCKS[key] = run_lock

    return inst, run_lock
|
|
|
|
|
|
|
|
|
|
|
|
|
# State for the single shared ("global") remover used by the *_Global nodes.
_GLOBAL_LOCK = threading.Lock()      # guards creation/deletion of the instance
_GLOBAL_RUN_LOCK = threading.Lock()  # serializes inference on the shared instance
_GLOBAL_REMOVER: Optional[Remover] = None
_GLOBAL_ON_DEVICE: str = "cpu"       # last known device of the shared instance
_GLOBAL_VRAM_DELTA_BYTES: int = 0    # best-effort VRAM cost of loading the weights
|
|
|
|
|
def _create_global_remover_cpu() -> Remover:
    """
    Create the Remover configured like InspyrenetRembg3 (jit=False),
    but *try* to force CPU init to avoid VRAM OOM during creation.
    """
    _ensure_ckpt_base()

    # Preferred path: newer transparent-background accepts a device argument.
    try:
        r = Remover(device="cpu")
        try:
            r.device = "cpu"
        except Exception:
            pass
        return r
    except TypeError:
        # Older library version: no `device` keyword; fall through below.
        pass

    # Fallback: construct normally, then move the weights to CPU ourselves.
    r = Remover()
    try:
        if hasattr(r, "model"):
            r.model = r.model.to("cpu")
        r.device = "cpu"
    except Exception:
        pass
    _cuda_soft_cleanup()
    return r
|
|
|
|
|
def _get_global_remover() -> Remover:
    """Return the process-wide shared Remover, creating it lazily on first use."""
    global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE
    with _GLOBAL_LOCK:
        inst = _GLOBAL_REMOVER
        if inst is None:
            # First use: build on CPU to avoid VRAM pressure during init.
            inst = _create_global_remover_cpu()
            _GLOBAL_REMOVER = inst
            _GLOBAL_ON_DEVICE = str(getattr(inst, "device", "cpu"))
        return inst
|
|
|
|
|
def _move_global_to_cpu() -> None:
    """Offload the shared remover's weights to CPU and flush CUDA caches."""
    global _GLOBAL_ON_DEVICE
    remover = _get_global_remover()
    try:
        if hasattr(remover, "model"):
            remover.model = remover.model.to("cpu")
        remover.device = "cpu"
        _GLOBAL_ON_DEVICE = "cpu"
    except Exception:
        # Best effort: a failed move leaves the previous state untouched.
        pass
    _cuda_soft_cleanup()
|
|
|
|
|
def _load_global_to_comfy_cuda_no_crash(max_evictions: int = 32) -> bool:
    """
    Load the global remover into VRAM on ComfyUI's chosen CUDA device.
    Never crashes on OOM: evicts smallest model first, then unload_all as last resort.
    Also records a best-effort VRAM delta.

    Returns True when the remover ends up on CUDA, False when it stays on CPU.
    """
    global _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES

    r = _get_global_remover()
    dev = _get_comfy_torch_device()

    # No CUDA target available: keep/put the remover on CPU and report failure.
    if dev.type != "cuda" or not torch.cuda.is_available():
        _move_global_to_cpu()
        return False

    # Already resident on some CUDA device: nothing to move.
    cur_dev = str(getattr(r, "device", "") or "")
    if cur_dev.startswith("cuda"):
        _GLOBAL_ON_DEVICE = cur_dev
        return True

    _set_current_cuda_device(dev)

    # Snapshot free VRAM before the move to estimate the weights' footprint.
    free_before = _cuda_free_bytes_on(dev)

    # Try the move; on OOM, evict one model per iteration and retry.
    for _ in range(max_evictions + 1):
        try:

            if hasattr(r, "model"):
                r.model = r.model.to(dev)
            r.device = str(dev)
            _GLOBAL_ON_DEVICE = str(dev)

            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()

            # Best-effort VRAM delta (weights residency only; not peak inference).
            free_after = _cuda_free_bytes_on(dev)
            if free_before is not None and free_after is not None:
                delta = max(0, int(free_before) - int(free_after))
                if delta > 0:
                    _GLOBAL_VRAM_DELTA_BYTES = delta

            return True

        except BaseException as e:
            # Anything other than OOM is a real error: propagate it.
            if not _is_oom_error(e):
                raise
            _comfy_soft_empty_cache()
            _cuda_soft_cleanup()

            # Evict the smallest tracked model; if none remain, drop everything.
            if not _comfy_unload_one_smallest_model():
                _comfy_unload_all_models()

    # Ran out of eviction attempts: fall back to CPU without crashing.
    _move_global_to_cpu()
    return False
|
|
|
|
|
def _run_global_rgba_no_crash(pil_rgb: Image.Image, fallback_rgba: Image.Image) -> Image.Image:
    """
    Run remover.process() (rgba output), matching InspyrenetRembg3 behavior.
    On OOM: evict models and retry, then CPU fallback.
    If output alpha is fully transparent, return fallback (prevents "invisible" output).

    The original implementation repeated the same try/process block four
    times with escalating recovery; this version factors the attempt and the
    recovery steps out while preserving the exact escalation order.
    """
    r = _get_global_remover()

    # Best effort: get the weights onto ComfyUI's CUDA device first.
    _load_global_to_comfy_cuda_no_crash()

    def _attempt() -> Image.Image:
        # One serialized inference pass; substitute the visible fallback if
        # the result came back fully transparent.
        with _GLOBAL_RUN_LOCK:
            with torch.inference_mode():
                out = r.process(pil_rgb, type="rgba")
                if _alpha_is_all_zero(out):
                    return fallback_rgba
                return out

    def _recover_evict_one() -> None:
        # Flush caches, then evict the smallest ComfyUI-tracked model.
        _comfy_soft_empty_cache()
        _cuda_soft_cleanup()
        _comfy_unload_one_smallest_model()

    # Escalating OOM recovery between attempts: evict one model, evict all
    # models, finally move the remover itself to CPU.
    for recover in (_recover_evict_one, _comfy_unload_all_models, _move_global_to_cpu):
        try:
            return _attempt()
        except BaseException as e:
            # Non-OOM errors are real bugs: propagate them.
            if not _is_oom_error(e):
                raise
            recover()

    # Final attempt (now on CPU); never crash — return the visible fallback.
    try:
        return _attempt()
    except BaseException:
        return fallback_rgba
|
|
|
|
|
|
|
|
|
|
|
|
|
class InspyrenetRembg2:
    """Background-removal node: returns RGBA images plus the alpha as a MASK."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "torchscript_jit": (["default", "on"],)
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image, torchscript_jit):
        # "on" selects the TorchScript-compiled model path.
        use_jit = torchscript_jit != "default"
        remover, run_lock = _get_remover(jit=use_jit)

        outputs = []
        for frame in tqdm(image, "Inspyrenet Rembg2"):
            pil_frame = tensor2pil(frame)
            try:
                with run_lock:
                    with torch.inference_mode():
                        rgba = remover.process(pil_frame, type="rgba")
            except BaseException as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                    raise RuntimeError("InspyrenetRembg2: CUDA out of memory.") from e
                raise

            tensor_out = pil2tensor(rgba)
            outputs.append(tensor_out)
            del pil_frame, rgba, tensor_out

        batch = torch.cat(outputs, dim=0)
        # The alpha channel doubles as the mask output.
        return (batch, batch[:, :, :, 3])
|
|
|
|
|
class InspyrenetRembg3:
    """Background-removal node without options; input alpha is flattened on white."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image):
        remover, run_lock = _get_remover(jit=False)

        results = []
        for frame in tqdm(image, "Inspyrenet Rembg3"):
            pil_frame = tensor2pil(frame)
            # Flatten any incoming transparency onto white before matting.
            rgb_frame = _rgba_to_rgb_on_white(pil_frame)

            try:
                with run_lock:
                    with torch.inference_mode():
                        rgba = remover.process(rgb_frame, type="rgba")
            except BaseException as e:
                if _is_oom_error(e):
                    _cuda_soft_cleanup()
                    raise RuntimeError("InspyrenetRembg3: CUDA out of memory.") from e
                raise

            tensor_out = pil2tensor(rgba)
            results.append(tensor_out)
            del pil_frame, rgb_frame, rgba, tensor_out

        return (torch.cat(results, dim=0),)
|
|
|
|
|
|
|
|
|
|
|
|
|
class Load_Inspyrenet_Global:
    """
    No inputs. Creates the global remover (once) and moves it to ComfyUI's CUDA device (if possible).
    Returns:
      - loaded_ok (BOOLEAN)
      - vram_delta_bytes (INT) best-effort (weights residency only; not peak inference)
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}}

    RETURN_TYPES = ("BOOLEAN", "INT")
    FUNCTION = "load"
    CATEGORY = "image"

    def load(self):
        # Make sure the singleton exists before attempting the VRAM move.
        _get_global_remover()
        loaded = _load_global_to_comfy_cuda_no_crash()
        return (bool(loaded), int(_GLOBAL_VRAM_DELTA_BYTES))
|
|
|
|
|
class Remove_Inspyrenet_Global:
    """
    Offload global remover to CPU or delete it.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "action": (["offload_to_cpu", "delete_instance"],),
            }
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "remove"
    CATEGORY = "image"

    def remove(self, action):
        global _GLOBAL_REMOVER, _GLOBAL_ON_DEVICE, _GLOBAL_VRAM_DELTA_BYTES

        if action == "offload_to_cpu":
            _move_global_to_cpu()
            return (True,)

        # delete_instance: move weights off the GPU, then drop the singleton.
        with _GLOBAL_LOCK:
            try:
                inst = _GLOBAL_REMOVER
                if inst is not None:
                    try:
                        if hasattr(inst, "model"):
                            inst.model = inst.model.to("cpu")
                        inst.device = "cpu"
                    except Exception:
                        pass
                _GLOBAL_REMOVER = None
                _GLOBAL_ON_DEVICE = "cpu"
                _GLOBAL_VRAM_DELTA_BYTES = 0
            except Exception:
                pass

        _cuda_soft_cleanup()
        return (True,)
|
|
|
|
|
class Run_InspyrenetRembg_Global:
    """
    No settings. Same behavior as InspyrenetRembg3, but uses the global remover and won't crash on OOM.
    On failure/OOM, returns a visible passthrough (opaque RGBA), NOT an invisible image.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image"

    def remove_background(self, image):
        # Make sure the shared remover exists before the per-frame loop.
        _get_global_remover()

        frames = []
        for frame in tqdm(image, "Run InspyrenetRembg Global"):
            pil_frame = tensor2pil(frame)

            # Opaque passthrough used when matting fails entirely.
            visible_fallback = _force_rgba_opaque(pil_frame)

            # Flatten any incoming transparency onto white before matting.
            rgb_frame = _rgba_to_rgb_on_white(pil_frame)

            result_pil = _run_global_rgba_no_crash(rgb_frame, visible_fallback)
            result = pil2tensor(result_pil)
            frames.append(result)

            del pil_frame, visible_fallback, rgb_frame, result_pil, result

        return (torch.cat(frames, dim=0),)
|
|
|
|
|
# Registration table consumed by ComfyUI at import time.
NODE_CLASS_MAPPINGS = {
    "InspyrenetRembg2": InspyrenetRembg2,
    "InspyrenetRembg3": InspyrenetRembg3,

    "Load_Inspyrenet_Global": Load_Inspyrenet_Global,
    "Remove_Inspyrenet_Global": Remove_Inspyrenet_Global,
    "Run_InspyrenetRembg_Global": Run_InspyrenetRembg_Global,
}

# Human-readable node titles shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "InspyrenetRembg2": "Inspyrenet Rembg2",
    "InspyrenetRembg3": "Inspyrenet Rembg3",

    "Load_Inspyrenet_Global": "Load Inspyrenet Global",
    "Remove_Inspyrenet_Global": "Remove Inspyrenet Global",
    "Run_InspyrenetRembg_Global": "Run InspyrenetRembg Global",
}