| import os |
| import gc |
| import time |
| import platform |
| import ctypes |
| from ctypes import wintypes |
| import torch |
| import torch.nn.functional as F |
| import comfy.model_management as model_management |
| import comfy.sample as _sample |
| import comfy.samplers as _samplers |
| import comfy.utils as _utils |
|
|
| try: |
| import psutil |
| except Exception: |
| psutil = None |
|
|
|
|
| def _get_ram_mb() -> float: |
| try: |
| if psutil is not None: |
| p = psutil.Process(os.getpid()) |
| rss = float(p.memory_info().rss) |
| try: |
| private = getattr(p.memory_full_info(), "private", None) |
| if isinstance(private, (int, float)) and private > 0: |
| rss = float(private) |
| except Exception: |
| pass |
| return rss / (1024.0 * 1024.0) |
| except Exception: |
| pass |
| return 0.0 |
|
|
|
|
| def _get_vram_mb_per_device() -> list[tuple[int, float, float]]: |
| out = [] |
| try: |
| if torch.cuda.is_available(): |
| for d in range(torch.cuda.device_count()): |
| try: |
| reserved = float(torch.cuda.memory_reserved(d)) / (1024.0 * 1024.0) |
| allocated = float(torch.cuda.memory_allocated(d)) / (1024.0 * 1024.0) |
| except Exception: |
| reserved = 0.0 |
| allocated = 0.0 |
| out.append((d, reserved, allocated)) |
| except Exception: |
| pass |
| return out |
|
|
|
|
| def _trim_working_set_windows(): |
| try: |
| if platform.system().lower().startswith("win"): |
| kernel32 = ctypes.windll.kernel32 |
| proc = kernel32.GetCurrentProcess() |
| kernel32.SetProcessWorkingSetSize(proc, ctypes.c_size_t(-1), ctypes.c_size_t(-1)) |
| except Exception: |
| pass |
|
|
|
|
| def _enable_win_privileges(names): |
| """Best-effort enable a set of Windows privileges for the current process.""" |
| try: |
| if not platform.system().lower().startswith('win'): |
| return False |
| advapi32 = ctypes.windll.advapi32 |
| kernel32 = ctypes.windll.kernel32 |
| token = wintypes.HANDLE() |
| TOKEN_ADJUST_PRIVILEGES = 0x20 |
| TOKEN_QUERY = 0x8 |
| if not advapi32.OpenProcessToken(kernel32.GetCurrentProcess(), TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, ctypes.byref(token)): |
| return False |
|
|
| class LUID(ctypes.Structure): |
| _fields_ = [("LowPart", wintypes.DWORD), ("HighPart", wintypes.LONG)] |
|
|
| class LUID_AND_ATTRIBUTES(ctypes.Structure): |
| _fields_ = [("Luid", LUID), ("Attributes", wintypes.DWORD)] |
|
|
| class TOKEN_PRIVILEGES(ctypes.Structure): |
| _fields_ = [("PrivilegeCount", wintypes.DWORD), ("Privileges", LUID_AND_ATTRIBUTES * 1)] |
|
|
| SE_PRIVILEGE_ENABLED = 0x2 |
| success = False |
| for name in names: |
| luid = LUID() |
| if not advapi32.LookupPrivilegeValueW(None, ctypes.c_wchar_p(name), ctypes.byref(luid)): |
| continue |
| tp = TOKEN_PRIVILEGES() |
| tp.PrivilegeCount = 1 |
| tp.Privileges[0].Luid = luid |
| tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED |
| if advapi32.AdjustTokenPrivileges(token, False, ctypes.byref(tp), 0, None, None): |
| success = True |
| return success |
| except Exception: |
| return False |
|
|
|
|
| def _system_cache_trim_windows(): |
| """Attempt to purge standby/file caches on Windows (requires privileges).""" |
| try: |
| if not platform.system().lower().startswith('win'): |
| return False |
| _enable_win_privileges([ |
| 'SeIncreaseQuotaPrivilege', |
| 'SeProfileSingleProcessPrivilege', |
| 'SeDebugPrivilege', |
| ]) |
| try: |
| kernel32 = ctypes.windll.kernel32 |
| SIZE_T = ctypes.c_size_t |
| kernel32.SetSystemFileCacheSize(SIZE_T(-1), SIZE_T(-1), wintypes.DWORD(0)) |
| except Exception: |
| pass |
| try: |
| ntdll = ctypes.windll.ntdll |
| SystemMemoryListInformation = 0x50 |
| MemoryPurgeStandbyList = ctypes.c_ulong(4) |
| ntdll.NtSetSystemInformation(wintypes.ULONG(SystemMemoryListInformation), ctypes.byref(MemoryPurgeStandbyList), ctypes.sizeof(MemoryPurgeStandbyList)) |
| except Exception: |
| pass |
| return True |
| except Exception: |
| return False |
|
|
|
|
def cleanup_memory(sync_cuda: bool = True, hard_trim: bool = True) -> dict:
    """Run a best-effort cleanup of RAM/VRAM and report what was released.

    Args:
        sync_cuda: Call ``torch.cuda.synchronize()`` first when CUDA is available.
        hard_trim: Additionally unload all ComfyUI models, collect garbage
            twice more, call ``comfy.utils.cleanup_lru_caches`` when present,
            and ask the OS to reclaim process/system memory.

    Returns:
        Dict with ``ram_before_mb``, ``ram_after_mb``, ``ram_freed_mb`` and
        ``gpu`` (one dict per CUDA device with reserved/allocated
        before/after/freed values in MiB). "Freed" values are clamped at 0.
    """

    def _best_effort(stage) -> None:
        # Each stage is optional: a failure must not abort the remaining ones.
        try:
            stage()
        except Exception:
            pass

    def _synchronize():
        if sync_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()

    def _comfy_soft_empty():
        import comfy.model_management as mm
        if hasattr(mm, 'soft_empty_cache'):
            mm.soft_empty_cache()

    def _release_cuda_cache():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def _yield_thread():
        time.sleep(0)

    def _comfy_unload_models():
        import comfy.model_management as mm
        if hasattr(mm, 'unload_all_models'):
            mm.unload_all_models()

    def _collect_twice():
        for _ in range(2):
            time.sleep(0)
            gc.collect()

    def _clear_lru_caches():
        if hasattr(_utils, 'cleanup_lru_caches'):
            _utils.cleanup_lru_caches()

    def _empty_working_set():
        _trim_working_set_windows()
        psapi = ctypes.windll.psapi
        kernel32 = ctypes.windll.kernel32
        psapi.EmptyWorkingSet(kernel32.GetCurrentProcess())

    def _glibc_malloc_trim():
        if platform.system().lower().startswith('linux'):
            libc = ctypes.CDLL('libc.so.6')
            libc.malloc_trim(0)

    ram_before = _get_ram_mb()
    vram_before = {dev: (reserved, allocated) for dev, reserved, allocated in _get_vram_mb_per_device()}

    for stage in (_synchronize, _comfy_soft_empty, gc.collect, _release_cuda_cache, _yield_thread):
        _best_effort(stage)

    # NOTE(review): the source indentation was lost; all six stages below are
    # taken to belong to the hard_trim branch -- confirm against the original.
    if hard_trim:
        for stage in (
            _comfy_unload_models,
            _collect_twice,
            _clear_lru_caches,
            _empty_working_set,
            _glibc_malloc_trim,
            _system_cache_trim_windows,
        ):
            _best_effort(stage)

    ram_after = _get_ram_mb()

    def _device_report(dev, reserved_after, allocated_after):
        # Devices missing from the "before" snapshot are treated as 0 MiB.
        reserved_before, allocated_before = vram_before.get(dev, (0.0, 0.0))
        return {
            "device": dev,
            "reserved_before_mb": reserved_before,
            "reserved_after_mb": reserved_after,
            "reserved_freed_mb": max(0.0, reserved_before - reserved_after),
            "allocated_before_mb": allocated_before,
            "allocated_after_mb": allocated_after,
            "allocated_freed_mb": max(0.0, allocated_before - allocated_after),
        }

    return {
        "ram_before_mb": ram_before,
        "ram_after_mb": ram_after,
        "ram_freed_mb": max(0.0, ram_before - ram_after),
        "gpu": [_device_report(*row) for row in _get_vram_mb_per_device()],
    }
|
|
|
|
class MG_CleanUp:
    """ComfyUI node that passes a latent through unchanged while freeing RAM/VRAM.

    Outputs the incoming latent plus a preview image. The preview is a black
    32x32 placeholder unless ``model``, ``positive``, ``negative`` and ``vae``
    are all supplied to :meth:`apply` and the preview render succeeds.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required and optional inputs."""
        return {
            "required": {
                "samples": ("LATENT", {}),
            },
            "optional": {
                "hard_trim": ("BOOLEAN", {"default": True, "tooltip": "Aggressively free RAM/VRAM and ask OS to return pages to the system."}),
                "sync_cuda": ("BOOLEAN", {"default": True, "tooltip": "Synchronize CUDA before cleanup to flush pending kernels."}),
                "hires_only_threshold": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 64, "tooltip": "Apply only when latent longest side >= threshold (0 == always)."}),
            }
        }

    RETURN_TYPES = ("LATENT", "IMAGE")
    RETURN_NAMES = ("samples", "Preview")
    FUNCTION = "apply"
    CATEGORY = "MagicNodes"

    @staticmethod
    def _below_threshold(samples, hires_only_threshold) -> bool:
        """Return True when the latent is smaller than the threshold (skip cleanup).

        A threshold of 0 (or less), a missing/shape-less latent, fewer than
        four dimensions, or any error all mean "do not skip".
        """
        try:
            limit = int(hires_only_threshold)
            if limit <= 0:
                return False
            z = samples.get("samples", None)
            if z is None or not hasattr(z, "shape") or len(z.shape) < 4:
                return False
            # Read H/W from the trailing axes so latents with more than four
            # dimensions work too; a fixed 4-way unpack raised on those and
            # the threshold was silently ignored.
            # Assumes the spatial axes are the last two -- TODO confirm for 5-D latents.
            height, width = int(z.shape[-2]), int(z.shape[-1])
            return max(height, width) < limit
        except Exception:
            return False

    def _render_preview(self, samples, model, positive, negative, vae):
        """Sample one low-denoise ddim step on a 32x32 latent and decode it.

        Returns the decoded image tensor, or None when the latent is not a
        4-D tensor or any step fails.
        """
        try:
            lat = samples.get("samples", None)
            if not isinstance(lat, torch.Tensor) or lat.ndim != 4:
                return None
            target = 32
            if lat.shape[-2] == target and lat.shape[-1] == target:
                z_ds = lat
            else:
                z_ds = F.interpolate(lat, size=(target, target), mode='bilinear', align_corners=False)
            if hasattr(_sample, 'fix_empty_latent_channels'):
                lat_img = _sample.fix_empty_latent_channels(model, z_ds)
            else:
                lat_img = z_ds
            batch_inds = samples.get("batch_index", None)
            noise = _sample.prepare_noise(lat_img, int(0), batch_inds)
            out = _sample.sample(
                model, noise, int(1), float(1.0), "ddim", "normal",
                positive, negative, lat_img,
                denoise=float(0.10), disable_noise=False, start_step=None, last_step=None,
                force_full_denoise=False, noise_mask=None, callback=None,
                disable_pbar=not _utils.PROGRESS_BAR_ENABLED, seed=int(0)
            )
            img = vae.decode(out)
            if len(img.shape) == 5:
                # Fold the extra leading axis into the batch axis.
                img = img.reshape(-1, img.shape[-3], img.shape[-2], img.shape[-1])
            return img
        except Exception:
            return None

    def apply(self, samples, hard_trim=True, sync_cuda=True, hires_only_threshold=0,
              model=None, positive=None, negative=None, vae=None):
        """Free memory (two passes) and return the latent with a preview image.

        Args:
            samples: LATENT dict; its ``"samples"`` entry is used for the size
                check and the optional preview.
            hard_trim: Forwarded to ``cleanup_memory`` for both passes.
            sync_cuda: Forwarded to ``cleanup_memory`` for the first pass only.
            hires_only_threshold: Skip cleanup when the latent's longest
                spatial side is below this value; 0 means always clean up.
            model, positive, negative, vae: When all four are given, a small
                preview is rendered before cleanup. They are not declared in
                ``INPUT_TYPES``.

        Returns:
            ``(samples, image)`` where ``samples`` is the input object itself
            and ``image`` is the rendered preview or a zero-filled
            ``(1, 32, 32, 3)`` tensor.
        """
        img_prev = None
        if all(item is not None for item in (model, positive, negative, vae)):
            img_prev = self._render_preview(samples, model, positive, negative, vae)

        try:
            if not self._below_threshold(samples, hires_only_threshold):
                print("=== CleanUP RAM and GPU ===")
                stats = cleanup_memory(sync_cuda=bool(sync_cuda), hard_trim=bool(hard_trim))
                try:
                    print(f"RAM freed: {stats['ram_freed_mb']:.1f} MB (before {stats['ram_before_mb']:.1f} -> after {stats['ram_after_mb']:.1f})")
                except Exception:
                    pass
                try:
                    for g in stats.get("gpu", []):
                        print(
                            f"GPU{g['device']}: reserved freed {g['reserved_freed_mb']:.1f} MB ( {g['reserved_before_mb']:.1f} -> {g['reserved_after_mb']:.1f} ), "
                            f"allocated freed {g['allocated_freed_mb']:.1f} MB ( {g['allocated_before_mb']:.1f} -> {g['allocated_after_mb']:.1f} )"
                        )
                except Exception:
                    pass

                # Second pass after a short pause; only gains are reported.
                try:
                    time.sleep(0.150)
                    stats2 = cleanup_memory(sync_cuda=False, hard_trim=bool(hard_trim))
                    if stats2 and float(stats2.get('ram_freed_mb', 0.0)) > 0.0:
                        print(f"2nd pass: RAM freed +{stats2['ram_freed_mb']:.1f} MB")
                    for g in stats2.get('gpu', []):
                        if float(g.get('reserved_freed_mb', 0.0)) > 0.0 or float(g.get('allocated_freed_mb', 0.0)) > 0.0:
                            print(f"2nd pass GPU{g['device']}: reserved +{g['reserved_freed_mb']:.1f} MB, allocated +{g['allocated_freed_mb']:.1f} MB")
                except Exception:
                    pass
                print("done.")
        except Exception:
            # Cleanup is best-effort; never fail the node because of it.
            pass

        if img_prev is None:
            try:
                device = model_management.intermediate_device() if hasattr(model_management, 'intermediate_device') else 'cpu'
                img_prev = torch.zeros((1, 32, 32, 3), dtype=torch.float32, device=device)
            except Exception:
                img_prev = torch.zeros((1, 32, 32, 3))
        return (samples, img_prev)
|
|
|
|
|
|
|
|