# z-image-studio / tests/test_modes.py
# Last commit: 7953225 (unverified) by techfreakworm
#   "refactor(upscale): drop LoRA support per spec"
from unittest.mock import MagicMock
import pytest
from PIL import Image
import modes
@pytest.fixture
def fake_pipe():
    """Provide a mock pipeline that records calls and yields a placeholder image.

    Calling the mock returns a 64x64 RGB image filled with (255, 176, 46), so
    callers receive a real ``PIL.Image``. Tests inspect ``call_args`` to see
    what kwargs were forwarded. ``dit`` and ``model_pool`` start out as
    separate MagicMock children.
    """
    placeholder = Image.new("RGB", (64, 64), color=(255, 176, 46))
    mock_pipeline = MagicMock(return_value=placeholder)
    mock_pipeline.dit = MagicMock()
    mock_pipeline.model_pool = MagicMock()
    return mock_pipeline
def test_t2i_turbo_builds_minimal_call(fake_pipe):
    """Turbo issues exactly one pipeline call with its settings forwarded.

    Checks prompt, cfg 1.0, 8 steps, output size, seed, sigma_shift 3.0, and
    that no non-empty negative prompt is sent. Also checks the returned
    image and metadata.
    """
    out, meta = modes.call_t2i(
        fake_pipe,
        params=dict(
            prompt="a cat",
            negative_prompt="",
            model="Turbo",
            steps=8,
            cfg=1.0,
            width=1024,
            height=1024,
            seed=42,
            lora_path=None,
            lora_strength=0.0,
        ),
    )
    fake_pipe.assert_called_once()
    kwargs = fake_pipe.call_args.kwargs
    assert kwargs["prompt"] == "a cat"
    assert kwargs["cfg_scale"] == 1.0
    assert kwargs["num_inference_steps"] == 8
    assert kwargs["width"] == 1024
    assert kwargs["height"] == 1024
    assert kwargs["seed"] == 42
    assert kwargs["sigma_shift"] == 3.0
    # An empty negative prompt may be omitted entirely or passed as a falsy
    # value; either is acceptable, but it must not carry text.
    assert not kwargs.get("negative_prompt")
    assert meta["model"] == "Turbo"
    assert meta["steps"] == 8
    assert isinstance(out, Image.Image)
def test_t2i_base_passes_negative_prompt_and_cfg4(fake_pipe):
    """Base forwards its negative prompt, cfg 4.0 and 25 steps to the pipeline."""
    base_params = {
        "prompt": "a cat",
        "negative_prompt": "blurry, lowres",
        "model": "Base",
        "steps": 25,
        "cfg": 4.0,
        "width": 1024,
        "height": 1024,
        "seed": 42,
        "lora_path": None,
        "lora_strength": 0.0,
    }
    modes.call_t2i(fake_pipe, params=base_params)

    sent = fake_pipe.call_args.kwargs
    expected = {
        "negative_prompt": "blurry, lowres",
        "cfg_scale": 4.0,
        "num_inference_steps": 25,
    }
    for key, value in expected.items():
        assert sent[key] == value, key
def test_t2i_swaps_transformer_via_pool_index(fake_pipe):
    """Base picks pool.model[0]; Turbo picks pool.model[1] (load-order indexed)."""
    base_dit, turbo_dit = object(), object()
    # The pool lists two z_image_dit entries in load order (Base first, Turbo
    # second), followed by an unrelated VAE decoder entry.
    pool = fake_pipe._zis_pool
    pool.model = [base_dit, turbo_dit, "vae_decoder_obj"]
    pool.model_name = ["z_image_dit", "z_image_dit", "flux_vae_decoder"]

    # (model choice, steps, cfg, transformer expected on pipe.dit afterwards)
    cases = [
        ("Base", 25, 4.0, base_dit),
        ("Turbo", 8, 1.0, turbo_dit),
    ]
    for model_name, step_count, guidance, expected_dit in cases:
        modes.call_t2i(
            fake_pipe,
            params=dict(
                prompt="x",
                negative_prompt="",
                model=model_name,
                steps=step_count,
                cfg=guidance,
                width=1024,
                height=1024,
                seed=0,
                lora_path=None,
                lora_strength=0.0,
            ),
        )
        assert fake_pipe.dit is expected_dit, model_name
def test_controlnet_calls_preprocessor_then_pipeline(fake_pipe, monkeypatch):
    """ControlNet runs the chosen preprocessor once and feeds its output to the pipeline.

    The fake preprocessor returns a new image rather than its input, so the
    test can check that the preprocessed image (not the raw upload) is what
    ends up in ``controlnet_inputs``.
    """
    canny_called = []
    # Same size as the input but a different object, so identity checks can
    # tell the preprocessed image apart from the raw input.
    preprocessed = Image.new("RGB", (1024, 1024), color=(255, 255, 255))

    def fake_run(mode, img):
        canny_called.append((mode, img.size))
        return preprocessed

    monkeypatch.setattr(modes, "preprocessors", type("P", (), {"run": staticmethod(fake_run)}))
    input_image = Image.new("RGB", (1024, 1024))
    _out, meta = modes.call_controlnet(
        fake_pipe,
        params=dict(
            prompt="cinematic portrait",
            input_image=input_image,
            preprocessor="Canny",
            controlnet_scale=1.0,
            steps=9,
            seed=42,
            lora_path=None,
            lora_strength=0.0,
        ),
    )
    assert canny_called == [("Canny", (1024, 1024))]
    fake_pipe.assert_called_once()
    kwargs = fake_pipe.call_args.kwargs
    assert "controlnet_inputs" in kwargs
    cn_in = kwargs["controlnet_inputs"]
    assert len(cn_in) == 1
    assert cn_in[0].image is preprocessed  # preprocessor output, not the raw input
    assert cn_in[0].scale == 1.0
    assert kwargs["num_inference_steps"] == 9
    assert kwargs["cfg_scale"] == 1.0
    assert meta["preprocessor"] == "Canny"
def test_controlnet_rejects_missing_input_image(fake_pipe):
    """A missing input image raises ValueError before the pipeline is invoked."""
    with pytest.raises(ValueError):
        modes.call_controlnet(
            fake_pipe,
            params=dict(
                prompt="x",
                input_image=None,
                preprocessor="Canny",
                controlnet_scale=1.0,
                steps=9,
                seed=0,
                lora_path=None,
                lora_strength=0.0,
            ),
        )
    # Validation must happen up front: no generation work for invalid input.
    fake_pipe.assert_not_called()
def test_upscale_runs_realesrgan_then_pipeline(fake_pipe, monkeypatch):
    """Upscale doubles the image via ``upscale.realesrgan_2x``, then refines it in the pipeline."""
    captured = {}
    esrgan_path = "/fake/path/RealESRGAN_x4plus.pth"

    def double_size(img, model_path):
        captured["upscale"] = (img.size, str(model_path))
        w, h = img.size
        return img.resize((w * 2, h * 2), Image.LANCZOS)

    fake_upscale = type("U", (), {"realesrgan_2x": staticmethod(double_size)})
    monkeypatch.setattr(modes, "upscale", fake_upscale)

    _out, meta = modes.call_upscale(
        fake_pipe,
        params={
            "prompt": "masterpiece, 8k",
            "input_image": Image.new("RGB", (512, 512)),
            "refine_steps": 5,
            "refine_denoise": 0.33,
            "seed": 42,
            "esrgan_model_path": esrgan_path,
        },
    )

    assert captured.get("upscale") == ((512, 512), esrgan_path)
    sent = fake_pipe.call_args.kwargs
    assert sent["input_image"].size == (1024, 1024)  # doubled by double_size
    assert sent["denoising_strength"] == 0.33
    assert sent["num_inference_steps"] == 5
    assert sent["cfg_scale"] == 1.0
    # The height/width passed to the pipeline must describe the upscaled image;
    # otherwise add_noise fails on mismatched input-latent and noise shapes.
    assert (sent["width"], sent["height"]) == (1024, 1024)
    assert meta["mode"] == "upscale"
def test_upscale_crops_to_multiple_of_16(fake_pipe, monkeypatch):
    """Regression: an upscaled image with non-16-aligned sides is cropped before refining.

    DiffSynth rounds height/width *up* to a multiple of 16 when it builds the
    noise tensor, but its VAE rounds *down* when encoding the input latents.
    With non-aligned dimensions the two shapes disagreed and add_noise
    crashed, so the image is cropped to a multiple of 16 before it is passed
    to the pipeline.
    """
    def misaligned_2x(img, model_path):
        # Ignores its input; neither 1240 nor 728 is a multiple of 16.
        return Image.new("RGB", (1240, 728))

    monkeypatch.setattr(
        modes, "upscale", type("U", (), {"realesrgan_2x": staticmethod(misaligned_2x)})
    )
    _out, meta = modes.call_upscale(
        fake_pipe,
        params={
            "prompt": "masterpiece, 8k",
            "input_image": Image.new("RGB", (620, 364)),
            "refine_steps": 5,
            "refine_denoise": 0.33,
            "seed": 0,
            "esrgan_model_path": "/fake/path/RealESRGAN_x4plus.pth",
        },
    )

    expected_w, expected_h = 1232, 720  # 1240 // 16 * 16, 728 // 16 * 16
    sent = fake_pipe.call_args.kwargs
    assert (sent["width"], sent["height"]) == (expected_w, expected_h)
    assert sent["input_image"].size == (expected_w, expected_h)
    assert (meta["width"], meta["height"]) == (expected_w, expected_h)
def test_upscale_rejects_missing_image(fake_pipe):
    """A missing input image raises ValueError before the pipeline is invoked."""
    with pytest.raises(ValueError):
        modes.call_upscale(
            fake_pipe,
            params=dict(
                prompt="x",
                input_image=None,
                refine_steps=5,
                refine_denoise=0.33,
                seed=0,
                esrgan_model_path="/fake.pth",
            ),
        )
    # Validation must happen up front: no refinement pass for invalid input.
    fake_pipe.assert_not_called()
def test_controlnet_falls_back_when_preprocessor_raises(fake_pipe, monkeypatch):
    """If the preprocessor raises, generation still runs using the raw input as control image."""
    def boom(mode, img):
        raise RuntimeError("preprocessor exploded")

    monkeypatch.setattr(modes, "preprocessors", type("P", (), {"run": staticmethod(boom)}))
    input_image = Image.new("RGB", (512, 512))
    _out, _meta = modes.call_controlnet(
        fake_pipe,
        params=dict(
            prompt="x",
            input_image=input_image,
            preprocessor="Canny",
            controlnet_scale=1.0,
            steps=9,
            seed=0,
            lora_path=None,
            lora_strength=0.0,
        ),
    )
    # Pipeline still ran: assert explicitly rather than relying on call_args
    # being non-None.
    fake_pipe.assert_called_once()
    kwargs = fake_pipe.call_args.kwargs
    cn_in = kwargs["controlnet_inputs"]
    assert len(cn_in) == 1
    assert cn_in[0].image is input_image  # the raw input, not a preprocessed image