techfreakworm committed on
Commit
84d00fe
·
unverified ·
1 Parent(s): 3762756

feat(modes): upscale handler (realesrgan + z-image-turbo refinement)

Browse files
Files changed (2) hide show
  1. modes.py +35 -0
  2. tests/test_modes.py +39 -0
modes.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
8
 
9
  import lora
10
  import preprocessors
 
11
 
12
  try:
13
  from diffsynth.diffusion.base_pipeline import ControlNetInput
@@ -111,3 +112,37 @@ def call_controlnet(pipe: Any, params: dict[str, Any]) -> tuple[Image.Image, dic
111
  lora_strength=params.get("lora_strength", 0.0),
112
  )
113
  return image, meta
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  import lora
10
  import preprocessors
11
+ import upscale
12
 
13
  try:
14
  from diffsynth.diffusion.base_pipeline import ControlNetInput
 
112
  lora_strength=params.get("lora_strength", 0.0),
113
  )
114
  return image, meta
115
+
116
+
117
def call_upscale(pipe: Any, params: dict[str, Any]) -> tuple[Image.Image, dict[str, Any]]:
    """Upscale an image, then refine it with a short Turbo img2img pass.

    The input image is first handed to ``upscale.realesrgan_2x`` (described by
    the original author as RealESRGAN x4 followed by a 0.5 resize), the
    pipeline is switched to the "Turbo" transformer via ``_swap_transformer``,
    and the enlarged image is then passed to ``pipe`` as ``input_image``.

    Args:
        pipe: Callable pipeline; it is invoked with keyword arguments only.
        params: Request parameters. ``input_image`` and ``esrgan_model_path``
            are required. Optional keys and their defaults: ``prompt``
            ("masterpiece, 8k"), ``refine_steps`` (5), ``refine_denoise``
            (0.33), ``seed`` (0), ``lora_path`` (None), ``lora_strength``
            (0.0). For ``prompt``, ``refine_steps``, ``refine_denoise`` and
            ``seed`` an explicit ``None`` is treated like a missing key.

    Returns:
        ``(image, meta)``: whatever ``pipe`` returned, plus a dict recording
        the settings used for the run and the size of the upscaled image.

    Raises:
        ValueError: If ``params`` has no ``input_image`` (or it is ``None``).
        KeyError: If ``params`` has no ``esrgan_model_path``.
    """

    def setting(key: str, default: Any) -> Any:
        # dict.get() only falls back when the key is absent; callers that
        # forward optional fields as None would otherwise crash in int()/float()
        # or hand prompt=None to the pipeline.
        value = params.get(key)
        return default if value is None else value

    input_image: Image.Image | None = params.get("input_image")
    if input_image is None:
        raise ValueError("Upscale mode requires an input image")

    upscaled = upscale.realesrgan_2x(input_image, model_path=params["esrgan_model_path"])

    _swap_transformer(pipe, "Turbo")

    kwargs: dict[str, Any] = dict(
        prompt=setting("prompt", "masterpiece, 8k"),
        cfg_scale=1.0,
        num_inference_steps=int(setting("refine_steps", 5)),
        sigma_shift=3.0,
        input_image=upscaled,
        denoising_strength=float(setting("refine_denoise", 0.33)),
        seed=int(setting("seed", 0)),
    )

    # Look the LoRA settings up once so the pipeline call and the metadata
    # are guaranteed to report the same values.
    lora_path = params.get("lora_path")
    lora_strength = params.get("lora_strength", 0.0)

    with lora.applied_lora(pipe, lora_path, lora_strength):
        image = pipe(**kwargs)

    meta = dict(
        mode="upscale",
        model="Turbo",
        refine_steps=kwargs["num_inference_steps"],
        refine_denoise=kwargs["denoising_strength"],
        seed=kwargs["seed"],
        # Dimensions are those of the image fed into the pipeline.
        width=upscaled.size[0],
        height=upscaled.size[1],
        lora=str(lora_path) if lora_path else None,
        lora_strength=lora_strength,
    )
    return image, meta
tests/test_modes.py CHANGED
@@ -110,3 +110,42 @@ def test_controlnet_rejects_missing_input_image(fake_pipe):
110
  controlnet_scale=1.0, steps=9, seed=0,
111
  lora_path=None, lora_strength=0.0),
112
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  controlnet_scale=1.0, steps=9, seed=0,
111
  lora_path=None, lora_strength=0.0),
112
  )
113
+
114
+
115
def test_upscale_runs_realesrgan_then_pipeline(fake_pipe, monkeypatch):
    """call_upscale sends the upscaler's output to the pipeline with the refine settings."""
    seen = {"upscale": None}

    def doubling_stub(img, model_path):
        # Stand-in for upscale.realesrgan_2x: records its arguments and
        # returns the image with both sides doubled.
        seen["upscale"] = (img.size, str(model_path))
        width, height = img.size
        return img.resize((width * 2, height * 2), Image.LANCZOS)

    stub_module = type("U", (), {"realesrgan_2x": staticmethod(doubling_stub)})
    monkeypatch.setattr(modes, "upscale", stub_module)

    request = {
        "prompt": "masterpiece, 8k",
        "input_image": Image.new("RGB", (512, 512)),
        "refine_steps": 5,
        "refine_denoise": 0.33,
        "seed": 42,
        "lora_path": None,
        "lora_strength": 0.0,
        "esrgan_model_path": "/fake/path/RealESRGAN_x4plus.pth",
    }
    _, meta = modes.call_upscale(fake_pipe, params=request)

    # The stub saw the original image and the configured model path.
    assert seen["upscale"] == ((512, 512), "/fake/path/RealESRGAN_x4plus.pth")

    pipe_kwargs = fake_pipe.call_args.kwargs
    assert pipe_kwargs["input_image"].size == (1024, 1024)  # doubled by the stub
    assert pipe_kwargs["denoising_strength"] == 0.33
    assert pipe_kwargs["num_inference_steps"] == 5
    assert pipe_kwargs["cfg_scale"] == 1.0
    assert meta["mode"] == "upscale"
144
+
145
+
146
def test_upscale_rejects_missing_image(fake_pipe):
    """call_upscale raises ValueError when input_image is None."""
    request = {
        "prompt": "x",
        "input_image": None,
        "refine_steps": 5,
        "refine_denoise": 0.33,
        "seed": 0,
        "lora_path": None,
        "lora_strength": 0.0,
        "esrgan_model_path": "/fake.pth",
    }
    with pytest.raises(ValueError):
        modes.call_upscale(fake_pipe, params=request)