techfreakworm commited on
Commit
3762756
·
unverified ·
1 Parent(s): 8f6ce7f

feat(modes): controlnet handler (turbo + union 2.1 + preprocessor)

Browse files
Files changed (2) hide show
  1. modes.py +50 -0
  2. tests/test_modes.py +42 -0
modes.py CHANGED
@@ -7,6 +7,17 @@ from typing import Any, TypedDict
7
  from PIL import Image
8
 
9
  import lora
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  class T2IParams(TypedDict, total=False):
@@ -61,3 +72,42 @@ def call_t2i(pipe: Any, params: T2IParams) -> tuple[Image.Image, dict[str, Any]]
61
  lora_strength=params.get("lora_strength", 0.0),
62
  )
63
  return image, meta
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from PIL import Image
8
 
9
  import lora
10
+ import preprocessors
11
+
12
+ try:
13
+ from diffsynth.diffusion.base_pipeline import ControlNetInput
14
+ except ImportError:
15
+ from dataclasses import dataclass
16
+
17
+ @dataclass
18
+ class ControlNetInput: # type: ignore[no-redef]
19
+ image: Any
20
+ scale: float = 1.0
21
 
22
 
23
  class T2IParams(TypedDict, total=False):
 
72
  lora_strength=params.get("lora_strength", 0.0),
73
  )
74
  return image, meta
75
+
76
+
77
+ def call_controlnet(pipe: Any, params: dict[str, Any]) -> tuple[Image.Image, dict[str, Any]]:
78
+ """ControlNet — Turbo + Z-Image-Turbo-Fun-Controlnet-Union-2.1."""
79
+ input_image: Image.Image | None = params.get("input_image")
80
+ if input_image is None:
81
+ raise ValueError("ControlNet mode requires an input image")
82
+
83
+ preproc_mode = params.get("preprocessor", "Canny")
84
+ control_image = preprocessors.run(preproc_mode, input_image)
85
+
86
+ _swap_transformer(pipe, "Turbo")
87
+
88
+ cn_input = ControlNetInput(image=control_image, scale=float(params.get("controlnet_scale", 1.0)))
89
+
90
+ kwargs: dict[str, Any] = dict(
91
+ prompt=params["prompt"],
92
+ cfg_scale=1.0,
93
+ num_inference_steps=int(params.get("steps", 9)),
94
+ sigma_shift=3.0,
95
+ height=control_image.size[1],
96
+ width=control_image.size[0],
97
+ seed=int(params.get("seed", 0)),
98
+ controlnet_inputs=[cn_input],
99
+ )
100
+
101
+ with lora.applied_lora(pipe, params.get("lora_path"), params.get("lora_strength", 0.0)):
102
+ image = pipe(**kwargs)
103
+
104
+ meta = dict(
105
+ mode="controlnet", model="Turbo",
106
+ preprocessor=preproc_mode,
107
+ controlnet_scale=cn_input.scale,
108
+ steps=kwargs["num_inference_steps"], cfg=1.0,
109
+ seed=kwargs["seed"], width=kwargs["width"], height=kwargs["height"],
110
+ lora=str(params.get("lora_path")) if params.get("lora_path") else None,
111
+ lora_strength=params.get("lora_strength", 0.0),
112
+ )
113
+ return image, meta
tests/test_modes.py CHANGED
@@ -68,3 +68,45 @@ def test_t2i_swaps_transformer_via_model_pool(fake_pipe):
68
  fake_pipe.model_pool.fetch_model.assert_called()
69
  call = fake_pipe.model_pool.fetch_model.call_args
70
  assert call.args[0] == "z_image_dit"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  fake_pipe.model_pool.fetch_model.assert_called()
69
  call = fake_pipe.model_pool.fetch_model.call_args
70
  assert call.args[0] == "z_image_dit"
71
+
72
+
73
+ def test_controlnet_calls_preprocessor_then_pipeline(fake_pipe, monkeypatch):
74
+ canny_called = []
75
+ def fake_run(mode, img):
76
+ canny_called.append((mode, img.size))
77
+ return img # passthrough for test
78
+ monkeypatch.setattr(modes, "preprocessors", type("P", (), {"run": staticmethod(fake_run)}))
79
+
80
+ input_image = Image.new("RGB", (1024, 1024))
81
+ out, meta = modes.call_controlnet(
82
+ fake_pipe,
83
+ params=dict(
84
+ prompt="cinematic portrait",
85
+ input_image=input_image,
86
+ preprocessor="Canny",
87
+ controlnet_scale=1.0,
88
+ steps=9,
89
+ seed=42,
90
+ lora_path=None, lora_strength=0.0,
91
+ ),
92
+ )
93
+
94
+ assert canny_called == [("Canny", (1024, 1024))]
95
+ kwargs = fake_pipe.call_args.kwargs
96
+ assert "controlnet_inputs" in kwargs
97
+ cn_in = kwargs["controlnet_inputs"]
98
+ assert len(cn_in) == 1
99
+ assert cn_in[0].scale == 1.0
100
+ assert kwargs["num_inference_steps"] == 9
101
+ assert kwargs["cfg_scale"] == 1.0
102
+ assert meta["preprocessor"] == "Canny"
103
+
104
+
105
+ def test_controlnet_rejects_missing_input_image(fake_pipe):
106
+ with pytest.raises(ValueError):
107
+ modes.call_controlnet(
108
+ fake_pipe,
109
+ params=dict(prompt="x", input_image=None, preprocessor="Canny",
110
+ controlnet_scale=1.0, steps=9, seed=0,
111
+ lora_path=None, lora_strength=0.0),
112
+ )