techfreakworm commited on
Commit
3b83775
·
unverified ·
1 Parent(s): 8894ed9

feat(backend): ZImageStudioBackend with @spaces.GPU and mode dispatch

Browse files
Files changed (2) hide show
  1. backend.py +61 -0
  2. tests/test_backend.py +46 -0
backend.py CHANGED
@@ -10,6 +10,8 @@ try:
10
  except ImportError:
11
  spaces = None # type: ignore[assignment]
12
 
 
 
13
 
14
  _BASE_DURATION_S: dict[str, int] = {
15
  "t2i": 20, # fixed setup + decode
@@ -42,3 +44,62 @@ def duration_for(
42
 
43
  est = (base + per_step * steps + cold_buffer) * size_factor * multiplier
44
  return max(60, min(int(est), 180))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  except ImportError:
11
  spaces = None # type: ignore[assignment]
12
 
13
+ import modes
14
+
15
 
16
  _BASE_DURATION_S: dict[str, int] = {
17
  "t2i": 20, # fixed setup + decode
 
44
 
45
  est = (base + per_step * steps + cold_buffer) * size_factor * multiplier
46
  return max(60, min(int(est), 180))
47
+
48
+
49
+ def _identity(fn):
50
+ return fn
51
+
52
+
53
+ _ON_SPACES = bool(os.environ.get("SPACES_ZERO_GPU"))
54
+ _GPU = spaces.GPU(duration=lambda *a, **kw: duration_for(*a[1:3], **kw)) \
55
+ if (spaces is not None and _ON_SPACES) else _identity
56
+
57
+
58
+ def _build_pipeline() -> Any:
59
+ """Construct the DiffSynth ZImagePipeline. Imported lazily to keep tests fast."""
60
+ import torch
61
+ from diffsynth.pipelines.z_image import ZImagePipeline
62
+
63
+ import models
64
+
65
+ device = models.auto_device()
66
+ vram_cfg: dict[str, Any] = {}
67
+ if device != "cpu":
68
+ vram_cfg = dict(
69
+ offload_dtype=torch.bfloat16, offload_device="cpu",
70
+ onload_dtype=torch.bfloat16, onload_device="cpu",
71
+ preparing_dtype=torch.bfloat16, preparing_device=device,
72
+ computation_dtype=torch.bfloat16, computation_device=device,
73
+ )
74
+
75
+ pipe = ZImagePipeline.from_pretrained(
76
+ torch_dtype=torch.bfloat16,
77
+ device=device,
78
+ model_configs=models.build_diffsynth_configs(vram_cfg=vram_cfg),
79
+ tokenizer_config=models.build_diffsynth_configs(
80
+ (models.TOKENIZER_CONFIG,), vram_cfg=None,
81
+ )[0],
82
+ vram_limit=models.vram_limit_for(device),
83
+ )
84
+ return pipe
85
+
86
+
87
+ _DISPATCH = {
88
+ "t2i": modes.call_t2i,
89
+ "controlnet": modes.call_controlnet,
90
+ "upscale": modes.call_upscale,
91
+ }
92
+
93
+
94
+ class ZImageStudioBackend:
95
+ """One-process backend wrapping the DiffSynth ZImagePipeline."""
96
+
97
+ def __init__(self) -> None:
98
+ self.pipeline = _build_pipeline()
99
+
100
+ @_GPU
101
+ def generate(self, mode: str, params: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
102
+ handler = _DISPATCH.get(mode)
103
+ if handler is None:
104
+ raise ValueError(f"unknown mode: {mode!r}; expected one of {list(_DISPATCH)}")
105
+ return handler(self.pipeline, params)
tests/test_backend.py CHANGED
@@ -32,3 +32,49 @@ def test_duration_upscale_has_realesrgan_overhead():
32
  t2i = backend.duration_for(mode="t2i", params=dict(model="Turbo", steps=8, width=1024, height=1024))
33
  upsc = backend.duration_for(mode="upscale", params=dict(refine_steps=5, width=1024, height=1024))
34
  assert upsc > t2i
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  t2i = backend.duration_for(mode="t2i", params=dict(model="Turbo", steps=8, width=1024, height=1024))
33
  upsc = backend.duration_for(mode="upscale", params=dict(refine_steps=5, width=1024, height=1024))
34
  assert upsc > t2i
35
+
36
+
37
+ from unittest.mock import MagicMock
38
+
39
+ import pytest
40
+ from PIL import Image
41
+
42
+
43
+ @pytest.fixture
44
+ def fake_backend(monkeypatch):
45
+ """A ZImageStudioBackend whose constructor doesn't actually build a pipeline."""
46
+ monkeypatch.setattr(backend, "_build_pipeline", lambda *a, **kw: MagicMock())
47
+ b = backend.ZImageStudioBackend()
48
+ b.pipeline.return_value = Image.new("RGB", (32, 32))
49
+ b.pipeline.dit = MagicMock()
50
+ b.pipeline.model_pool = MagicMock()
51
+ return b
52
+
53
+
54
+ def test_backend_generate_routes_t2i(fake_backend):
55
+ img, meta = fake_backend.generate(
56
+ mode="t2i",
57
+ params=dict(prompt="cat", negative_prompt="", model="Turbo",
58
+ steps=8, cfg=1.0, width=1024, height=1024, seed=42,
59
+ lora_path=None, lora_strength=0.0),
60
+ )
61
+ assert isinstance(img, Image.Image)
62
+ assert meta["mode"] == "t2i"
63
+ assert meta["model"] == "Turbo"
64
+
65
+
66
+ def test_backend_generate_routes_controlnet(fake_backend, monkeypatch):
67
+ monkeypatch.setattr(backend.modes, "preprocessors",
68
+ type("P", (), {"run": staticmethod(lambda m, i: i)}))
69
+ img, meta = fake_backend.generate(
70
+ mode="controlnet",
71
+ params=dict(prompt="cat", input_image=Image.new("RGB", (64, 64)),
72
+ preprocessor="Canny", controlnet_scale=1.0,
73
+ steps=9, seed=0, lora_path=None, lora_strength=0.0),
74
+ )
75
+ assert meta["mode"] == "controlnet"
76
+
77
+
78
+ def test_backend_generate_unknown_mode_raises(fake_backend):
79
+ with pytest.raises(ValueError):
80
+ fake_backend.generate(mode="dance", params={})