techfreakworm committed on
Commit
b855333
·
unverified ·
1 Parent(s): 036940b

feat(upscale): realesrgan x4 wrapper with 0.5-resize bridge

Browse files
Files changed (2) hide show
  1. tests/test_upscale.py +26 -0
  2. upscale.py +49 -0
tests/test_upscale.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest import mock
2
+ import pytest
3
+ from PIL import Image
4
+
5
+ import upscale
6
+
7
+
8
@pytest.fixture
def small_image():
    """Provide a solid-colour 256x256 RGB image to use as upscale input."""
    side = 256
    fill = (120, 50, 200)
    return Image.new("RGB", (side, side), color=fill)
11
+
12
+
13
def test_realesrgan_2x_produces_2x_image(small_image, monkeypatch):
    """A 4x upscale followed by the 0.5 resize must yield a net 2x image."""

    def stub_upscale_4x(_weights, img):
        # Stand-in for the real network: a plain 4x Lanczos resize.
        width, height = img.size
        return img.resize((width * 4, height * 4), Image.LANCZOS)

    monkeypatch.setattr(upscale, "_realesrgan_4x", stub_upscale_4x)

    result = upscale.realesrgan_2x(small_image, model_path="/dev/null")

    assert result.size == (512, 512)
22
+
23
+
24
def test_realesrgan_2x_rejects_none():
    """Passing ``None`` instead of an image must raise ``ValueError``."""
    missing_image = None
    with pytest.raises(ValueError):
        upscale.realesrgan_2x(missing_image, model_path="/dev/null")
upscale.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RealESRGAN x4plus wrapper + 0.5-resize bridge.
2
+
3
+ This module only handles the *pixel-space* upscale. The Z-Image-Turbo refinement
4
+ pass (img2img at denoise=0.33) lives in :mod:`modes` since it shares the pipeline.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from PIL import Image
12
+
13
+
14
def realesrgan_2x(image: Image.Image | None, model_path: Path | str) -> Image.Image:
    """Upscale ``image`` by a net factor of 2.

    The image is first enlarged 4x by ``_realesrgan_4x`` and the result is then
    resized to half its width and height with Lanczos resampling.

    Args:
        image: Input image; ``None`` is rejected.
        model_path: Location of the model weights, forwarded to ``_realesrgan_4x``.

    Returns:
        The resized image, half the size of the 4x output in each dimension.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("upscale needs an input image")

    enlarged = _realesrgan_4x(model_path, image)
    width, height = enlarged.size
    half_size = (width // 2, height // 2)
    return enlarged.resize(half_size, Image.LANCZOS)
21
+
22
+
23
# In-process cache of constructed upsamplers, keyed by the stringified weights
# path, so each model is built at most once per path.
_MODEL_CACHE: dict[str, Any] = {}


def _realesrgan_4x(model_path: Path | str, image: Image.Image) -> Image.Image:
    """Run RealESRGAN x4plus on ``image`` and return the 4x result as an RGB image.

    The upsampler for a given ``model_path`` is created on first use and reused
    from ``_MODEL_CACHE`` afterwards.

    ``RealESRGANer.enhance`` follows the OpenCV convention: it expects a BGR
    array, converts it to RGB for the network and converts the result back to
    BGR. The PIL image is therefore flipped to BGR before the call and the
    output flipped back to RGB, so the network sees the channels in the order
    it was trained on.

    Args:
        model_path: Location of the RealESRGAN x4plus weights.
        image: Input image; it is converted to RGB before processing.

    Returns:
        The upscaled image built from the array returned by ``enhance``.
    """
    import numpy as np
    from realesrgan import RealESRGANer
    from basicsr.archs.rrdbnet_arch import RRDBNet

    key = str(model_path)
    upsampler = _MODEL_CACHE.get(key)
    if upsampler is None:
        net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        upsampler = RealESRGANer(
            scale=4,
            model_path=key,
            model=net,
            tile=512,       # split into tiles to avoid OOM on large inputs
            tile_pad=10,
            pre_pad=0,
            half=False,     # bf16 elsewhere; keep this fp32 for stability
            gpu_id=None,
        )
        _MODEL_CACHE[key] = upsampler

    rgb = np.asarray(image.convert("RGB"))
    # enhance() takes and returns BGR (cv2 convention); reverse the channel axis
    # on the way in and out. ascontiguousarray avoids handing negative-stride
    # views to OpenCV / PIL.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    out_bgr, _ = upsampler.enhance(bgr, outscale=4)
    out_rgb = np.ascontiguousarray(out_bgr[:, :, ::-1])
    return Image.fromarray(out_rgb)