z-image-studio / upscale.py
techfreakworm's picture
fix: depth preprocessor name + basicsr/torchvision import shim
6d862f4 unverified
"""RealESRGAN x4plus wrapper + 0.5-resize bridge.
This module only handles the *pixel-space* upscale. The Z-Image-Turbo refinement
pass (img2img at denoise=0.33) lives in :mod:`modes` since it shares the pipeline.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from PIL import Image
def realesrgan_2x(image: Image.Image | None, model_path: Path | str) -> Image.Image:
    """Upscale ``image`` by a net factor of 2 in pixel space.

    The image is first enlarged with RealESRGAN x4plus, and the enlarged
    result is then Lanczos-resampled down to half of its width and height.

    Args:
        image: Input picture; ``None`` is rejected.
        model_path: Location of the RealESRGAN weights, forwarded unchanged
            to the 4x helper.

    Returns:
        The resized image produced from the 4x result.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("upscale needs an input image")

    enlarged = _realesrgan_4x(model_path, image)
    # Halve each side of the 4x output (floor division) to land on 2x overall.
    half_size = tuple(side // 2 for side in enlarged.size)
    return enlarged.resize(half_size, Image.LANCZOS)
# In-process cache of constructed RealESRGAN upsamplers, keyed by
# ``str(model_path)``. Entries are added on first use and never evicted.
_MODEL_CACHE: dict[str, Any] = {}
def _realesrgan_4x(model_path: Path | str, image: Image.Image) -> Image.Image:
    """Run RealESRGAN x4plus on ``image`` and return the upscaled RGB image.

    The upsampler is built once per ``model_path`` and stored in
    ``_MODEL_CACHE``; later calls with the same path reuse it. The heavy
    third-party dependencies are imported inside the function body.

    Args:
        model_path: Location of the RealESRGAN x4plus weights. Its string
            form is used both as the cache key and as the path handed to
            ``RealESRGANer``.
        image: Input picture. It is converted to RGB before processing, so
            any alpha channel is dropped.

    Returns:
        A new PIL image built from the array returned by
        ``RealESRGANer.enhance`` with ``outscale=4``, in RGB channel order.
    """
    import sys

    import numpy as np
    import torchvision.transforms.functional as _tvf

    # basicsr (a realesrgan dep) imports torchvision.transforms.functional_tensor,
    # which was removed in torchvision >=0.17. Alias the old path to the current
    # module so basicsr's degradations import keeps working.
    sys.modules.setdefault("torchvision.transforms.functional_tensor", _tvf)

    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer

    key = str(model_path)
    if key not in _MODEL_CACHE:
        net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        _MODEL_CACHE[key] = RealESRGANer(
            scale=4,
            model_path=key,
            model=net,
            tile=512,  # split into tiles to avoid OOM on large inputs
            tile_pad=10,
            pre_pad=0,
            half=False,  # bf16 elsewhere; keep this fp32 for stability
            gpu_id=None,
        )
    upsampler = _MODEL_CACHE[key]

    rgb_in = np.array(image.convert("RGB"))
    # RealESRGANer.enhance follows the OpenCV convention: it expects a BGR
    # array, converts it to RGB for the network, and returns BGR again.
    # Feeding it RGB directly makes the model see red and blue swapped, so
    # flip the channel axis on the way in and on the way out.
    bgr_in = np.ascontiguousarray(rgb_in[:, :, ::-1])
    bgr_out, _ = upsampler.enhance(bgr_in, outscale=4)
    rgb_out = np.ascontiguousarray(bgr_out[:, :, ::-1])
    return Image.fromarray(rgb_out)