Spaces:
Running on Zero
Running on Zero
feat(modes): upscale handler (realesrgan + z-image-turbo refinement)
Browse files
- modes.py +35 -0
- tests/test_modes.py +39 -0
modes.py
CHANGED
|
@@ -8,6 +8,7 @@ from PIL import Image
|
|
| 8 |
|
| 9 |
import lora
|
| 10 |
import preprocessors
|
|
|
|
| 11 |
|
| 12 |
try:
|
| 13 |
from diffsynth.diffusion.base_pipeline import ControlNetInput
|
|
@@ -111,3 +112,37 @@ def call_controlnet(pipe: Any, params: dict[str, Any]) -> tuple[Image.Image, dic
|
|
| 111 |
lora_strength=params.get("lora_strength", 0.0),
|
| 112 |
)
|
| 113 |
return image, meta
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
import lora
|
| 10 |
import preprocessors
|
| 11 |
+
import upscale
|
| 12 |
|
| 13 |
try:
|
| 14 |
from diffsynth.diffusion.base_pipeline import ControlNetInput
|
|
|
|
| 112 |
lora_strength=params.get("lora_strength", 0.0),
|
| 113 |
)
|
| 114 |
return image, meta
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def call_upscale(pipe: Any, params: dict[str, Any]) -> tuple[Image.Image, dict[str, Any]]:
    """Upscale an input image, then refine it with an img2img pass on "Turbo".

    The image is first enlarged by ``upscale.realesrgan_2x`` (the original
    summary describes this as RealESRGAN x4 followed by a 0.5 resize; that
    logic lives in the ``upscale`` module and is not visible here). The
    pipeline is then switched to the "Turbo" transformer and called with the
    enlarged image as ``input_image``.

    Args:
        pipe: Callable pipeline object; invoked with keyword arguments only.
        params: Request parameters. Reads ``input_image`` and
            ``esrgan_model_path`` (both required) plus the optional keys
            ``prompt``, ``refine_steps``, ``refine_denoise``, ``seed``,
            ``lora_path`` and ``lora_strength``.

    Returns:
        A ``(image, meta)`` pair: whatever ``pipe`` returned, and a dict
        describing the settings that were used.

    Raises:
        ValueError: If ``params`` has no ``input_image`` or it is ``None``.
        KeyError: If ``params`` has no ``esrgan_model_path``.
    """
    source: Image.Image | None = params.get("input_image")
    if source is None:
        raise ValueError("Upscale mode requires an input image")

    enlarged = upscale.realesrgan_2x(source, model_path=params["esrgan_model_path"])

    _swap_transformer(pipe, "Turbo")

    # Coerce the user-supplied numbers once; the same values feed both the
    # pipeline call and the returned metadata.
    steps = int(params.get("refine_steps", 5))
    denoise = float(params.get("refine_denoise", 0.33))
    seed = int(params.get("seed", 0))

    pipe_args: dict[str, Any] = {
        "prompt": params.get("prompt", "masterpiece, 8k"),
        "cfg_scale": 1.0,
        "num_inference_steps": steps,
        "sigma_shift": 3.0,
        "input_image": enlarged,
        "denoising_strength": denoise,
        "seed": seed,
    }

    lora_path = params.get("lora_path")
    lora_strength = params.get("lora_strength", 0.0)

    with lora.applied_lora(pipe, lora_path, lora_strength):
        result = pipe(**pipe_args)

    # Reported width/height are those of the enlarged image fed to the
    # pipeline, not of the pipeline's output.
    meta = {
        "mode": "upscale",
        "model": "Turbo",
        "refine_steps": steps,
        "refine_denoise": denoise,
        "seed": seed,
        "width": enlarged.size[0],
        "height": enlarged.size[1],
        "lora": str(lora_path) if lora_path else None,
        "lora_strength": lora_strength,
    }
    return result, meta
|
tests/test_modes.py
CHANGED
|
@@ -110,3 +110,42 @@ def test_controlnet_rejects_missing_input_image(fake_pipe):
|
|
| 110 |
controlnet_scale=1.0, steps=9, seed=0,
|
| 111 |
lora_path=None, lora_strength=0.0),
|
| 112 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
controlnet_scale=1.0, steps=9, seed=0,
|
| 111 |
lora_path=None, lora_strength=0.0),
|
| 112 |
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def test_upscale_runs_realesrgan_then_pipeline(fake_pipe, monkeypatch):
    """call_upscale enlarges via the upscale module, then calls the pipeline."""
    seen = None

    def doubling_stub(img, model_path):
        # Remember what the handler passed in, then return a 2x LANCZOS resize.
        nonlocal seen
        seen = (img.size, str(model_path))
        width, height = img.size
        return img.resize((width * 2, height * 2), Image.LANCZOS)

    class _StubUpscale:
        realesrgan_2x = staticmethod(doubling_stub)

    # Swap the whole `upscale` module reference held by `modes` for the stub.
    monkeypatch.setattr(modes, "upscale", _StubUpscale)

    request = {
        "prompt": "masterpiece, 8k",
        "input_image": Image.new("RGB", (512, 512)),
        "refine_steps": 5,
        "refine_denoise": 0.33,
        "seed": 42,
        "lora_path": None,
        "lora_strength": 0.0,
        "esrgan_model_path": "/fake/path/RealESRGAN_x4plus.pth",
    }
    out, meta = modes.call_upscale(fake_pipe, params=request)

    # The stub saw the original image size and the configured model path.
    assert seen == ((512, 512), "/fake/path/RealESRGAN_x4plus.pth")

    passed = fake_pipe.call_args.kwargs
    # The pipeline received the stub's doubled image, not the raw input.
    assert passed["input_image"].size == (1024, 1024)
    assert passed["denoising_strength"] == 0.33
    assert passed["num_inference_steps"] == 5
    assert passed["cfg_scale"] == 1.0
    assert meta["mode"] == "upscale"
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def test_upscale_rejects_missing_image(fake_pipe):
    """call_upscale raises ValueError when input_image is None."""
    request = {
        "prompt": "x",
        "input_image": None,
        "refine_steps": 5,
        "refine_denoise": 0.33,
        "seed": 0,
        "lora_path": None,
        "lora_strength": 0.0,
        "esrgan_model_path": "/fake.pth",
    }
    with pytest.raises(ValueError):
        modes.call_upscale(fake_pipe, params=request)
|