Spaces:
Running on Zero
Running on Zero
feat(modes): t2i handler (base + turbo) with transformer swap and lora ctx
Browse files- modes.py +63 -0
- tests/test_modes.py +70 -0
modes.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mode handlers — pure functions over a ZImagePipeline + params dict."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, TypedDict
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
import lora
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class T2IParams(TypedDict, total=False):
|
| 13 |
+
prompt: str
|
| 14 |
+
negative_prompt: str
|
| 15 |
+
model: str # "Base" | "Turbo"
|
| 16 |
+
steps: int
|
| 17 |
+
cfg: float
|
| 18 |
+
width: int
|
| 19 |
+
height: int
|
| 20 |
+
seed: int
|
| 21 |
+
lora_path: Path | None
|
| 22 |
+
lora_strength: float
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _swap_transformer(pipe: Any, model_name: str) -> None:
|
| 26 |
+
"""Swap the active transformer in the pipeline's model pool."""
|
| 27 |
+
variant = "z_image" if model_name == "Base" else "z_image_turbo"
|
| 28 |
+
pipe.dit = pipe.model_pool.fetch_model("z_image_dit", variant=variant)
|
| 29 |
+
try:
|
| 30 |
+
pipe.dit._zis_variant = variant
|
| 31 |
+
except (AttributeError, RuntimeError):
|
| 32 |
+
pass
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def call_t2i(pipe: Any, params: T2IParams) -> tuple[Image.Image, dict[str, Any]]:
|
| 36 |
+
"""Text-to-image. Routes to base (cfg=4, 25 steps) or turbo (cfg=1, 8 steps)."""
|
| 37 |
+
model_name = params.get("model", "Turbo")
|
| 38 |
+
is_base = model_name == "Base"
|
| 39 |
+
_swap_transformer(pipe, model_name)
|
| 40 |
+
|
| 41 |
+
kwargs: dict[str, Any] = dict(
|
| 42 |
+
prompt=params["prompt"],
|
| 43 |
+
cfg_scale=float(params.get("cfg", 4.0 if is_base else 1.0)),
|
| 44 |
+
num_inference_steps=int(params.get("steps", 25 if is_base else 8)),
|
| 45 |
+
sigma_shift=3.0,
|
| 46 |
+
height=int(params.get("height", 1024)),
|
| 47 |
+
width=int(params.get("width", 1024)),
|
| 48 |
+
seed=int(params.get("seed", 0)),
|
| 49 |
+
)
|
| 50 |
+
if is_base and params.get("negative_prompt"):
|
| 51 |
+
kwargs["negative_prompt"] = params["negative_prompt"]
|
| 52 |
+
|
| 53 |
+
with lora.applied_lora(pipe, params.get("lora_path"), params.get("lora_strength", 0.0)):
|
| 54 |
+
image = pipe(**kwargs)
|
| 55 |
+
|
| 56 |
+
meta = dict(
|
| 57 |
+
mode="t2i", model=model_name,
|
| 58 |
+
steps=kwargs["num_inference_steps"], cfg=kwargs["cfg_scale"],
|
| 59 |
+
seed=kwargs["seed"], width=kwargs["width"], height=kwargs["height"],
|
| 60 |
+
lora=str(params.get("lora_path")) if params.get("lora_path") else None,
|
| 61 |
+
lora_strength=params.get("lora_strength", 0.0),
|
| 62 |
+
)
|
| 63 |
+
return image, meta
|
tests/test_modes.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import MagicMock
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
import modes
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@pytest.fixture
|
| 10 |
+
def fake_pipe():
|
| 11 |
+
"""Stand-in pipeline that records its __call__ args and returns a dummy image."""
|
| 12 |
+
pipe = MagicMock()
|
| 13 |
+
pipe.dit = MagicMock()
|
| 14 |
+
pipe.model_pool = MagicMock()
|
| 15 |
+
pipe.return_value = Image.new("RGB", (64, 64), color=(255, 176, 46))
|
| 16 |
+
return pipe
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_t2i_turbo_builds_minimal_call(fake_pipe):
|
| 20 |
+
out, meta = modes.call_t2i(
|
| 21 |
+
fake_pipe,
|
| 22 |
+
params=dict(
|
| 23 |
+
prompt="a cat",
|
| 24 |
+
negative_prompt="",
|
| 25 |
+
model="Turbo",
|
| 26 |
+
steps=8, cfg=1.0,
|
| 27 |
+
width=1024, height=1024,
|
| 28 |
+
seed=42,
|
| 29 |
+
lora_path=None, lora_strength=0.0,
|
| 30 |
+
),
|
| 31 |
+
)
|
| 32 |
+
fake_pipe.assert_called_once()
|
| 33 |
+
kwargs = fake_pipe.call_args.kwargs
|
| 34 |
+
assert kwargs["prompt"] == "a cat"
|
| 35 |
+
assert kwargs["cfg_scale"] == 1.0
|
| 36 |
+
assert kwargs["num_inference_steps"] == 8
|
| 37 |
+
assert kwargs["width"] == 1024
|
| 38 |
+
assert kwargs["seed"] == 42
|
| 39 |
+
assert kwargs["sigma_shift"] == 3.0
|
| 40 |
+
assert "negative_prompt" not in kwargs or not kwargs.get("negative_prompt")
|
| 41 |
+
assert meta["model"] == "Turbo"
|
| 42 |
+
assert meta["steps"] == 8
|
| 43 |
+
assert isinstance(out, Image.Image)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_t2i_base_passes_negative_prompt_and_cfg4(fake_pipe):
|
| 47 |
+
modes.call_t2i(
|
| 48 |
+
fake_pipe,
|
| 49 |
+
params=dict(
|
| 50 |
+
prompt="a cat", negative_prompt="blurry, lowres",
|
| 51 |
+
model="Base", steps=25, cfg=4.0,
|
| 52 |
+
width=1024, height=1024, seed=42,
|
| 53 |
+
lora_path=None, lora_strength=0.0,
|
| 54 |
+
),
|
| 55 |
+
)
|
| 56 |
+
kwargs = fake_pipe.call_args.kwargs
|
| 57 |
+
assert kwargs["negative_prompt"] == "blurry, lowres"
|
| 58 |
+
assert kwargs["cfg_scale"] == 4.0
|
| 59 |
+
assert kwargs["num_inference_steps"] == 25
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def test_t2i_swaps_transformer_via_model_pool(fake_pipe):
|
| 63 |
+
modes.call_t2i(
|
| 64 |
+
fake_pipe,
|
| 65 |
+
params=dict(prompt="x", negative_prompt="", model="Base", steps=25, cfg=4.0,
|
| 66 |
+
width=1024, height=1024, seed=0, lora_path=None, lora_strength=0.0),
|
| 67 |
+
)
|
| 68 |
+
fake_pipe.model_pool.fetch_model.assert_called()
|
| 69 |
+
call = fake_pipe.model_pool.fetch_model.call_args
|
| 70 |
+
assert call.args[0] == "z_image_dit"
|