Spaces:
Running on Zero
Running on Zero
feat(backend): ZImageStudioBackend with @spaces.GPU and mode dispatch
Browse files- backend.py +61 -0
- tests/test_backend.py +46 -0
backend.py
CHANGED
|
@@ -10,6 +10,8 @@ try:
|
|
| 10 |
except ImportError:
|
| 11 |
spaces = None # type: ignore[assignment]
|
| 12 |
|
|
|
|
|
|
|
| 13 |
|
| 14 |
_BASE_DURATION_S: dict[str, int] = {
|
| 15 |
"t2i": 20, # fixed setup + decode
|
|
@@ -42,3 +44,62 @@ def duration_for(
|
|
| 42 |
|
| 43 |
est = (base + per_step * steps + cold_buffer) * size_factor * multiplier
|
| 44 |
return max(60, min(int(est), 180))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
except ImportError:
|
| 11 |
spaces = None # type: ignore[assignment]
|
| 12 |
|
| 13 |
+
import modes
|
| 14 |
+
|
| 15 |
|
| 16 |
_BASE_DURATION_S: dict[str, int] = {
|
| 17 |
"t2i": 20, # fixed setup + decode
|
|
|
|
| 44 |
|
| 45 |
est = (base + per_step * steps + cold_buffer) * size_factor * multiplier
|
| 46 |
return max(60, min(int(est), 180))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _identity(fn):
|
| 50 |
+
return fn
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# True whenever the SPACES_ZERO_GPU environment variable is set to any
# non-empty string.
_ON_SPACES = bool(os.environ.get("SPACES_ZERO_GPU"))


def _gpu_duration(*args, **kwargs):
    """Compute the duration value passed to ``spaces.GPU``.

    Positional arguments at index 1 and 2 are forwarded to ``duration_for``;
    index 0 is skipped (for a method call such as ``generate(self, mode,
    params)`` that slot is ``self``). Keyword arguments are passed through
    unchanged, so keyword-style calls work as well.

    NOTE(review): this relies on ``spaces.GPU`` invoking the ``duration``
    callable with the decorated call's own arguments -- confirm against the
    ZeroGPU documentation.
    """
    return duration_for(*args[1:3], **kwargs)


# Only touch ``spaces.GPU`` when the package imported successfully AND the
# environment flag is set; otherwise decorate with a no-op.
if spaces is not None and _ON_SPACES:
    _GPU = spaces.GPU(duration=_gpu_duration)
else:
    _GPU = _identity
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _build_pipeline() -> Any:
    """Construct the DiffSynth ZImagePipeline. Imported lazily to keep tests fast.

    The device comes from ``models.auto_device()``. For any device other than
    ``"cpu"`` a VRAM-management config is passed to the model configs; on CPU
    the config is left empty.

    Returns:
        The object returned by ``ZImagePipeline.from_pretrained``.
    """
    # Heavy imports stay inside the function so importing this module does not
    # pull them in.
    import torch
    from diffsynth.pipelines.z_image import ZImagePipeline

    import models

    device = models.auto_device()

    vram_cfg: dict[str, Any] = {}
    if device != "cpu":
        # Offload/onload stages target the CPU; preparing/computation stages
        # target the selected device. All four use bfloat16.
        # NOTE(review): exact meaning of each stage is defined by DiffSynth --
        # confirm against its VRAM-management docs.
        vram_cfg = dict(
            offload_dtype=torch.bfloat16, offload_device="cpu",
            onload_dtype=torch.bfloat16, onload_device="cpu",
            preparing_dtype=torch.bfloat16, preparing_device=device,
            computation_dtype=torch.bfloat16, computation_device=device,
        )

    # The tokenizer config is built through the same helper, without a VRAM
    # config; only the first element of its result is used.
    tokenizer_config = models.build_diffsynth_configs(
        (models.TOKENIZER_CONFIG,), vram_cfg=None,
    )[0]

    return ZImagePipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=models.build_diffsynth_configs(vram_cfg=vram_cfg),
        tokenizer_config=tokenizer_config,
        vram_limit=models.vram_limit_for(device),
    )
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Mode name -> handler in ``modes``. ``ZImageStudioBackend.generate`` looks the
# mode up here and calls the handler as ``handler(pipeline, params)``.
_DISPATCH = {
    "t2i": modes.call_t2i,
    "controlnet": modes.call_controlnet,
    "upscale": modes.call_upscale,
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ZImageStudioBackend:
    """One-process backend wrapping the DiffSynth ZImagePipeline.

    The pipeline is built once in ``__init__`` and reused for every
    ``generate`` call.
    """

    def __init__(self) -> None:
        # Looked up at call time, so a patched module-level ``_build_pipeline``
        # is honoured.
        self.pipeline = _build_pipeline()

    @_GPU
    def generate(self, mode: str, params: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
        """Run one generation request through the handler registered for *mode*.

        Args:
            mode: A key of ``_DISPATCH``.
            params: Mode-specific parameters, passed to the handler unchanged.

        Returns:
            Whatever the handler returns for ``handler(self.pipeline, params)``.

        Raises:
            ValueError: If *mode* is not a key of ``_DISPATCH``.
        """
        handler = _DISPATCH.get(mode)
        if handler is None:
            raise ValueError(f"unknown mode: {mode!r}; expected one of {list(_DISPATCH)}")
        return handler(self.pipeline, params)
|
tests/test_backend.py
CHANGED
|
@@ -32,3 +32,49 @@ def test_duration_upscale_has_realesrgan_overhead():
|
|
| 32 |
t2i = backend.duration_for(mode="t2i", params=dict(model="Turbo", steps=8, width=1024, height=1024))
|
| 33 |
upsc = backend.duration_for(mode="upscale", params=dict(refine_steps=5, width=1024, height=1024))
|
| 34 |
assert upsc > t2i
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
t2i = backend.duration_for(mode="t2i", params=dict(model="Turbo", steps=8, width=1024, height=1024))
|
| 33 |
upsc = backend.duration_for(mode="upscale", params=dict(refine_steps=5, width=1024, height=1024))
|
| 34 |
assert upsc > t2i
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
from unittest.mock import MagicMock
|
| 38 |
+
|
| 39 |
+
import pytest
|
| 40 |
+
from PIL import Image
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@pytest.fixture
def fake_backend(monkeypatch):
    """A ZImageStudioBackend whose constructor doesn't actually build a pipeline.

    ``backend._build_pipeline`` is patched to return a ``MagicMock``. Calling
    the mocked pipeline returns a 32x32 RGB image, and its ``dit`` and
    ``model_pool`` attributes are explicit mocks.
    """
    monkeypatch.setattr(backend, "_build_pipeline", lambda *a, **kw: MagicMock())
    b = backend.ZImageStudioBackend()
    b.pipeline.return_value = Image.new("RGB", (32, 32))
    b.pipeline.dit = MagicMock()
    b.pipeline.model_pool = MagicMock()
    return b
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_backend_generate_routes_t2i(fake_backend):
    """``generate(mode="t2i")`` yields an image plus metadata naming mode and model."""
    img, meta = fake_backend.generate(
        mode="t2i",
        params=dict(prompt="cat", negative_prompt="", model="Turbo",
                    steps=8, cfg=1.0, width=1024, height=1024, seed=42,
                    lora_path=None, lora_strength=0.0),
    )
    assert isinstance(img, Image.Image)
    assert meta["mode"] == "t2i"
    assert meta["model"] == "Turbo"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_backend_generate_routes_controlnet(fake_backend, monkeypatch):
    """``generate(mode="controlnet")`` reports ``"controlnet"`` in its metadata."""
    # Replace ``modes.preprocessors`` with a stub whose ``run(m, i)`` returns
    # its second argument untouched.
    monkeypatch.setattr(backend.modes, "preprocessors",
                        type("P", (), {"run": staticmethod(lambda m, i: i)}))
    img, meta = fake_backend.generate(
        mode="controlnet",
        params=dict(prompt="cat", input_image=Image.new("RGB", (64, 64)),
                    preprocessor="Canny", controlnet_scale=1.0,
                    steps=9, seed=0, lora_path=None, lora_strength=0.0),
    )
    assert meta["mode"] == "controlnet"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def test_backend_generate_unknown_mode_raises(fake_backend):
    """A mode missing from the dispatch table makes ``generate`` raise ``ValueError``."""
    with pytest.raises(ValueError):
        fake_backend.generate(mode="dance", params={})
|