| import comfy.utils |
| import torch |
| import gc |
| import logging |
| import comfy.model_management as model_management |
|
|
|
|
def clear_gpu_and_ram_cache():
    """Run a garbage-collection pass and, when CUDA is available, release CUDA caches.

    Always calls ``gc.collect()``. If ``torch.cuda.is_available()`` is true it
    additionally calls ``torch.cuda.empty_cache()`` followed by
    ``torch.cuda.ipc_collect()``. Returns ``None``.
    """
    gc.collect()

    # Nothing CUDA-related to release on machines without a CUDA device.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
|
|
|
|
def _smart_decode(vae, latent, tile_size=512):
    """Decode ``latent["samples"]`` with ``vae``, retrying tiled on out-of-memory.

    A plain ``vae.decode`` is attempted first. If it raises
    ``model_management.OOM_EXCEPTION`` a warning is logged and
    ``vae.decode_tiled`` is used instead, with the tile edge and overlap
    derived from ``tile_size`` divided by ``vae.spacial_compression_decode()``.

    Args:
        vae: Object providing ``decode``, ``decode_tiled`` and
            ``spacial_compression_decode``.
        latent: Mapping with a ``"samples"`` entry to decode.
        tile_size: Tile edge used for the tiled fallback, before division by
            the VAE's compression factor.

    Returns:
        The decoded images. A 5-dimensional result is reshaped to 4 dimensions
        by merging all leading axes into the first one.
    """
    latent_samples = latent["samples"]

    try:
        decoded = vae.decode(latent_samples)
    except model_management.OOM_EXCEPTION:
        logging.warning("VAE decode OOM, using tiled decode")
        factor = vae.spacial_compression_decode()
        tile_edge = tile_size // factor
        tile_overlap = (tile_size // 4) // factor
        decoded = vae.decode_tiled(
            latent_samples,
            tile_x=tile_edge,
            tile_y=tile_edge,
            overlap=tile_overlap,
        )

    # Only 5-D outputs need flattening; everything else is returned untouched.
    if len(decoded.shape) != 5:
        return decoded

    trailing = decoded.shape[-3:]
    return decoded.reshape(-1, trailing[0], trailing[1], trailing[2])
|
|
|
|
class MagicUpscaleModule:
    """Node that decodes a latent, resizes the image, and re-encodes it.

    Moved into mod/ as mg_upscale_module keeping class/key name.

    The class-level attributes (``INPUT_TYPES``, ``RETURN_TYPES``,
    ``RETURN_NAMES``, ``FUNCTION``, ``CATEGORY``) declare the node's inputs,
    outputs, entry-point method name and menu category.
    """

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    # Stride assumed when the VAE cannot report a usable compression factor.
    _FALLBACK_STRIDE = 8

    @classmethod
    def INPUT_TYPES(cls):
        """Return the declaration of the node's required inputs."""
        return {
            "required": {
                "samples": ("LATENT", {}),
                "vae": ("VAE", {}),
                "upscale_method": (cls.upscale_methods, {"default": "bilinear"}),
                "scale_by": ("FLOAT", {"default": 1.2, "min": 0.01, "max": 8.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("LATENT", "IMAGE")
    RETURN_NAMES = ("LATENT", "Upscaled Image")
    FUNCTION = "process_upscale"
    CATEGORY = "MagicNodes"

    @staticmethod
    def _aligned_size(length, scale_by, stride):
        """Scale ``length`` by ``scale_by`` and round up to a positive multiple of ``stride``.

        The result is never smaller than ``stride``: a tiny ``scale_by`` can
        round the scaled length down to 0, and a zero-sized target would make
        the subsequent resize fail.
        """
        scaled = round(length * scale_by)
        aligned = ((scaled + stride - 1) // stride) * stride
        return int(max(aligned, stride))

    @classmethod
    def _vae_stride(cls, vae):
        """Return the VAE's spatial compression factor, or the fallback stride.

        The fallback is used when ``vae.spacial_compression_decode()`` raises
        or yields a non-positive value.
        """
        try:
            stride = int(vae.spacial_compression_decode())
        except Exception:
            # Deliberate best-effort: any failure here falls back to the
            # default stride instead of aborting the upscale.
            logging.debug(
                "Could not query VAE compression factor; using stride %d",
                cls._FALLBACK_STRIDE,
                exc_info=True,
            )
            return cls._FALLBACK_STRIDE
        if stride <= 0:
            return cls._FALLBACK_STRIDE
        return stride

    def process_upscale(self, samples, vae, upscale_method, scale_by):
        """Decode ``samples``, upscale the image by ``scale_by`` and re-encode it.

        Args:
            samples: Latent mapping with a ``"samples"`` entry.
            vae: VAE used for both decoding and encoding.
            upscale_method: Resampling method name passed to
                ``comfy.utils.common_upscale``.
            scale_by: Multiplier applied to the decoded width and height.

        Returns:
            A tuple ``({"samples": encoded}, upscaled_image)``.
        """
        clear_gpu_and_ram_cache()
        images = _smart_decode(vae, samples)

        # Move the last axis to position 1 for the resize; after the move,
        # shape[2] is treated as height and shape[3] as width.
        channels_first = images.movedim(-1, 1)

        # Target sizes are rounded up to a multiple of the VAE stride and
        # clamped to at least one stride.
        stride = self._vae_stride(vae)
        width_al = self._aligned_size(channels_first.shape[3], scale_by, stride)
        height_al = self._aligned_size(channels_first.shape[2], scale_by, stride)

        up = comfy.utils.common_upscale(channels_first, width_al, height_al, upscale_method, "disabled")
        up = up.movedim(1, -1)

        # Only the first three entries of the last axis are passed to the encoder.
        encoded = vae.encode(up[:, :, :, :3])
        return ({"samples": encoded}, up)
|
|