techfreakworm committed on
Commit
8f6ce7f
·
unverified ·
1 Parent(s): b855333

feat(modes): t2i handler (base + turbo) with transformer swap and lora ctx

Browse files
Files changed (2) hide show
  1. modes.py +63 -0
  2. tests/test_modes.py +70 -0
modes.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Mode handlers — pure functions over a ZImagePipeline + params dict."""
2
+ from __future__ import annotations
3
+
4
+ from pathlib import Path
5
+ from typing import Any, TypedDict
6
+
7
+ from PIL import Image
8
+
9
+ import lora
10
+
11
+
12
class T2IParams(TypedDict, total=False):
    """Options accepted by ``call_t2i``.

    ``total=False`` makes every key optional for type checkers. At runtime
    ``call_t2i`` still indexes ``params["prompt"]`` directly, so ``prompt``
    must be supplied; the other keys are read with ``.get`` and have defaults.
    """

    # Text prompt; read with params["prompt"], so effectively required.
    prompt: str
    # Only forwarded to the pipeline when model == "Base" and the value is non-empty.
    negative_prompt: str
    model: str  # "Base" | "Turbo"; call_t2i defaults to "Turbo" when absent
    # Forwarded as num_inference_steps; default 25 for Base, 8 otherwise.
    steps: int
    # Forwarded as cfg_scale; default 4.0 for Base, 1.0 otherwise.
    cfg: float
    # Output size forwarded to the pipeline; both default to 1024.
    width: int
    height: int
    # Forwarded as seed; defaults to 0.
    seed: int
    # Passed to lora.applied_lora; a falsy value is recorded as lora=None in the metadata.
    lora_path: Path | None
    # Passed to lora.applied_lora alongside lora_path; defaults to 0.0.
    lora_strength: float
23
+
24
+
25
def _swap_transformer(pipe: Any, model_name: str) -> None:
    """Point ``pipe.dit`` at the transformer that matches ``model_name``.

    ``"Base"`` selects the ``"z_image"`` variant; any other name selects
    ``"z_image_turbo"``. The model is fetched from ``pipe.model_pool`` under
    the id ``"z_image_dit"`` and assigned to ``pipe.dit``.
    """
    if model_name == "Base":
        variant = "z_image"
    else:
        variant = "z_image_turbo"

    fetched = pipe.model_pool.fetch_model("z_image_dit", variant=variant)
    pipe.dit = fetched

    # Best-effort tag recording which variant is active. Objects that refuse
    # the attribute (AttributeError / RuntimeError) are left untagged.
    try:
        pipe.dit._zis_variant = variant
    except (AttributeError, RuntimeError):
        pass
33
+
34
+
35
def call_t2i(pipe: Any, params: T2IParams) -> tuple[Image.Image, dict[str, Any]]:
    """Text-to-image. Routes to base (cfg=4, 25 steps) or turbo (cfg=1, 8 steps).

    Args:
        pipe: Pipeline object. It must expose ``model_pool.fetch_model`` (used
            by ``_swap_transformer``) and be callable with the keyword
            arguments assembled below.
        params: Request options. ``prompt`` is required. For ``model``,
            ``cfg``, ``steps``, ``height``, ``width``, ``seed`` and
            ``lora_strength``, a key that is missing *or* explicitly ``None``
            falls back to its default, so a cleared optional field no longer
            crashes the ``float()``/``int()`` conversions.

    Returns:
        ``(image, meta)``: ``image`` is whatever ``pipe(**kwargs)`` returns and
        ``meta`` is a dict describing the effective settings of the run.

    Raises:
        KeyError: If ``params`` has no ``"prompt"`` key.
    """

    def option(key: str, default: Any) -> Any:
        # Explicit ``is None`` test so legitimate falsy values (seed=0,
        # cfg=0.0) are kept rather than replaced by the default.
        value = params.get(key)
        return default if value is None else value

    model_name = option("model", "Turbo")
    is_base = model_name == "Base"
    _swap_transformer(pipe, model_name)

    kwargs: dict[str, Any] = dict(
        prompt=params["prompt"],
        cfg_scale=float(option("cfg", 4.0 if is_base else 1.0)),
        num_inference_steps=int(option("steps", 25 if is_base else 8)),
        sigma_shift=3.0,
        height=int(option("height", 1024)),
        width=int(option("width", 1024)),
        seed=int(option("seed", 0)),
    )

    # The negative prompt is only forwarded on the Base route, and only when
    # it is non-empty; otherwise the key is omitted entirely.
    negative_prompt = params.get("negative_prompt")
    if is_base and negative_prompt:
        kwargs["negative_prompt"] = negative_prompt

    lora_path = params.get("lora_path")
    lora_strength = option("lora_strength", 0.0)
    # NOTE(review): presumably applied_lora attaches the LoRA only for the
    # duration of the ``with`` body — confirm against the lora module.
    with lora.applied_lora(pipe, lora_path, lora_strength):
        image = pipe(**kwargs)

    meta = dict(
        mode="t2i",
        model=model_name,
        steps=kwargs["num_inference_steps"],
        cfg=kwargs["cfg_scale"],
        seed=kwargs["seed"],
        width=kwargs["width"],
        height=kwargs["height"],
        lora=str(lora_path) if lora_path else None,
        lora_strength=lora_strength,
    )
    return image, meta
tests/test_modes.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import MagicMock
2
+
3
+ import pytest
4
+ from PIL import Image
5
+
6
+ import modes
7
+
8
+
9
@pytest.fixture
def fake_pipe():
    """Mock pipeline: calling it is recorded and yields a small solid-colour image."""
    dummy_image = Image.new("RGB", (64, 64), color=(255, 176, 46))
    stub = MagicMock()
    stub.return_value = dummy_image
    stub.model_pool = MagicMock()
    stub.dit = MagicMock()
    return stub
17
+
18
+
19
def test_t2i_turbo_builds_minimal_call(fake_pipe):
    """Turbo request: the pipeline is called once with the expected keyword arguments."""
    request = {
        "prompt": "a cat",
        "negative_prompt": "",
        "model": "Turbo",
        "steps": 8,
        "cfg": 1.0,
        "width": 1024,
        "height": 1024,
        "seed": 42,
        "lora_path": None,
        "lora_strength": 0.0,
    }
    out, meta = modes.call_t2i(fake_pipe, params=request)

    fake_pipe.assert_called_once()
    sent = fake_pipe.call_args.kwargs
    expected = {
        "prompt": "a cat",
        "cfg_scale": 1.0,
        "num_inference_steps": 8,
        "width": 1024,
        "seed": 42,
        "sigma_shift": 3.0,
    }
    for key, want in expected.items():
        assert sent[key] == want
    # Either the key is absent or its value is empty.
    assert not sent.get("negative_prompt")
    assert meta["model"] == "Turbo"
    assert meta["steps"] == 8
    assert isinstance(out, Image.Image)
44
+
45
+
46
def test_t2i_base_passes_negative_prompt_and_cfg4(fake_pipe):
    """Base request: negative prompt, cfg 4.0 and 25 steps reach the pipeline call."""
    request = {
        "prompt": "a cat",
        "negative_prompt": "blurry, lowres",
        "model": "Base",
        "steps": 25,
        "cfg": 4.0,
        "width": 1024,
        "height": 1024,
        "seed": 42,
        "lora_path": None,
        "lora_strength": 0.0,
    }
    modes.call_t2i(fake_pipe, params=request)

    sent = fake_pipe.call_args.kwargs
    assert sent["negative_prompt"] == "blurry, lowres"
    assert sent["cfg_scale"] == 4.0
    assert sent["num_inference_steps"] == 25
60
+
61
+
62
def test_t2i_swaps_transformer_via_model_pool(fake_pipe):
    """A Base request fetches the ``z_image`` variant of ``z_image_dit`` from the model pool."""
    modes.call_t2i(
        fake_pipe,
        params=dict(
            prompt="x", negative_prompt="", model="Base", steps=25, cfg=4.0,
            width=1024, height=1024, seed=0, lora_path=None, lora_strength=0.0,
        ),
    )
    fetch_model = fake_pipe.model_pool.fetch_model
    fetch_model.assert_called()
    call = fetch_model.call_args
    assert call.args[0] == "z_image_dit"
    # Checking only the model id would let a wrong variant slip through; the
    # handler passes the variant as a keyword, so pin it here as well.
    assert call.kwargs["variant"] == "z_image"