Spaces:
Running on Zero
Running on Zero
feat(modes): controlnet handler (turbo + union 2.1 + preprocessor)
Browse files
- modes.py +50 -0
- tests/test_modes.py +42 -0
modes.py
CHANGED
|
@@ -7,6 +7,17 @@ from typing import Any, TypedDict
|
|
| 7 |
from PIL import Image
|
| 8 |
|
| 9 |
import lora
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class T2IParams(TypedDict, total=False):
|
|
@@ -61,3 +72,42 @@ def call_t2i(pipe: Any, params: T2IParams) -> tuple[Image.Image, dict[str, Any]]
|
|
| 61 |
lora_strength=params.get("lora_strength", 0.0),
|
| 62 |
)
|
| 63 |
return image, meta
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
|
| 9 |
import lora
|
| 10 |
+
import preprocessors
|
| 11 |
+
|
| 12 |
+
try:
    from diffsynth.diffusion.base_pipeline import ControlNetInput
except ImportError:
    # The diffsynth import failed: define a minimal stand-in so this module
    # can still be imported. It carries only the two fields that this module
    # passes when building a ControlNet input (image and scale).
    from dataclasses import dataclass

    @dataclass
    class ControlNetInput:  # type: ignore[no-redef]
        # Control image passed along to the pipeline.
        image: Any
        # ControlNet scale; defaults to 1.0 when not given.
        scale: float = 1.0
|
| 21 |
|
| 22 |
|
| 23 |
class T2IParams(TypedDict, total=False):
|
|
|
|
| 72 |
lora_strength=params.get("lora_strength", 0.0),
|
| 73 |
)
|
| 74 |
return image, meta
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def call_controlnet(pipe: Any, params: dict[str, Any]) -> tuple[Image.Image, dict[str, Any]]:
    """ControlNet — Turbo + Z-Image-Turbo-Fun-Controlnet-Union-2.1.

    Preprocesses the input image into a control image, switches the pipeline
    to the "Turbo" transformer, and runs the pipeline with a single
    ControlNet input under the requested LoRA.

    Args:
        pipe: Pipeline object. It is passed to ``_swap_transformer`` and
            ``lora.applied_lora`` and then called with keyword arguments.
        params: Request values. ``prompt`` and ``input_image`` are required.
            Optional keys and their defaults: ``preprocessor`` ("Canny"),
            ``controlnet_scale`` (1.0), ``steps`` (9), ``seed`` (0),
            ``lora_path`` (None), ``lora_strength`` (0.0).

    Returns:
        A ``(image, meta)`` pair: whatever the pipeline call returned, and a
        dict describing the settings used for the run.

    Raises:
        ValueError: If ``input_image`` is missing or ``None``.
        KeyError: If ``prompt`` is missing.
    """
    input_image: Image.Image | None = params.get("input_image")
    if input_image is None:
        raise ValueError("ControlNet mode requires an input image")

    # Resolve and coerce every request value up front so a malformed request
    # fails before the preprocessor runs and before the transformer is swapped.
    prompt = params["prompt"]
    steps = int(params.get("steps", 9))
    seed = int(params.get("seed", 0))
    scale = float(params.get("controlnet_scale", 1.0))
    preproc_mode = params.get("preprocessor", "Canny")
    lora_path = params.get("lora_path")
    lora_strength = params.get("lora_strength", 0.0)

    control_image = preprocessors.run(preproc_mode, input_image)
    # Output resolution follows the control image, not the raw input.
    width, height = control_image.size

    _swap_transformer(pipe, "Turbo")

    cn_input = ControlNetInput(image=control_image, scale=scale)

    kwargs: dict[str, Any] = {
        "prompt": prompt,
        "cfg_scale": 1.0,
        "num_inference_steps": steps,
        "sigma_shift": 3.0,
        "height": height,
        "width": width,
        "seed": seed,
        "controlnet_inputs": [cn_input],
    }

    with lora.applied_lora(pipe, lora_path, lora_strength):
        image = pipe(**kwargs)

    meta = {
        "mode": "controlnet",
        "model": "Turbo",
        "preprocessor": preproc_mode,
        "controlnet_scale": cn_input.scale,
        "steps": steps,
        "cfg": 1.0,
        "seed": seed,
        "width": width,
        "height": height,
        "lora": str(lora_path) if lora_path else None,
        "lora_strength": lora_strength,
    }
    return image, meta
|
tests/test_modes.py
CHANGED
|
@@ -68,3 +68,45 @@ def test_t2i_swaps_transformer_via_model_pool(fake_pipe):
|
|
| 68 |
fake_pipe.model_pool.fetch_model.assert_called()
|
| 69 |
call = fake_pipe.model_pool.fetch_model.call_args
|
| 70 |
assert call.args[0] == "z_image_dit"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
fake_pipe.model_pool.fetch_model.assert_called()
|
| 69 |
call = fake_pipe.model_pool.fetch_model.call_args
|
| 70 |
assert call.args[0] == "z_image_dit"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def test_controlnet_calls_preprocessor_then_pipeline(fake_pipe, monkeypatch):
    """The preprocessor runs once and its output drives the pipeline call."""
    preprocessor_calls = []

    def fake_run(mode, img):
        preprocessor_calls.append((mode, img.size))
        return img  # passthrough for test

    monkeypatch.setattr(modes, "preprocessors", type("P", (), {"run": staticmethod(fake_run)}))

    # Non-square on purpose: a swapped width/height would otherwise go unnoticed.
    input_image = Image.new("RGB", (1024, 768))
    _out, meta = modes.call_controlnet(
        fake_pipe,
        params={
            "prompt": "cinematic portrait",
            "input_image": input_image,
            "preprocessor": "Canny",
            "controlnet_scale": 1.0,
            "steps": 9,
            "seed": 42,
            "lora_path": None,
            "lora_strength": 0.0,
        },
    )

    assert preprocessor_calls == [("Canny", (1024, 768))]

    kwargs = fake_pipe.call_args.kwargs
    assert "controlnet_inputs" in kwargs
    cn_in = kwargs["controlnet_inputs"]
    assert len(cn_in) == 1
    assert cn_in[0].scale == 1.0
    # The (passthrough) preprocessor output must be what reaches the ControlNet.
    assert cn_in[0].image is input_image

    assert kwargs["prompt"] == "cinematic portrait"
    assert kwargs["num_inference_steps"] == 9
    assert kwargs["cfg_scale"] == 1.0
    assert kwargs["seed"] == 42
    assert (kwargs["width"], kwargs["height"]) == (1024, 768)

    assert meta["mode"] == "controlnet"
    assert meta["preprocessor"] == "Canny"
    assert meta["seed"] == 42
    assert (meta["width"], meta["height"]) == (1024, 768)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def test_controlnet_rejects_missing_input_image(fake_pipe):
    """A request without an input image is rejected before the pipeline runs."""
    # `match` pins the specific validation error so an unrelated ValueError
    # (e.g. from a numeric coercion) cannot satisfy the test.
    with pytest.raises(ValueError, match="requires an input image"):
        modes.call_controlnet(
            fake_pipe,
            params={
                "prompt": "x",
                "input_image": None,
                "preprocessor": "Canny",
                "controlnet_scale": 1.0,
                "steps": 9,
                "seed": 0,
                "lora_path": None,
                "lora_strength": 0.0,
            },
        )

    # Validation happens first, so the pipeline itself must never be invoked.
    fake_pipe.assert_not_called()
|