techfreakworm committed on
Commit
8894ed9
·
unverified ·
1 Parent(s): 84d00fe

feat(backend): zerogpu duration estimator (clamped 60-180s)

Browse files
Files changed (2) hide show
  1. backend.py +44 -0
  2. tests/test_backend.py +34 -0
backend.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ZImageStudioBackend — wraps the DiffSynth pipeline; applies @spaces.GPU on HF Spaces."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ from typing import Any
6
+
7
+ # Spaces import is optional — running locally we don't have it.
8
+ try:
9
+ import spaces # type: ignore
10
+ except ImportError:
11
+ spaces = None # type: ignore[assignment]
12
+
13
+
14
# Fixed per-mode overhead, in seconds, added once per request.
_BASE_DURATION_S: dict[str, int] = {
    "t2i": 20,  # fixed setup + decode
    "controlnet": 30,  # + preprocessor + control patch
    "upscale": 50,  # + realesrgan pixel-space step
}
# Per-step cost, in seconds, keyed by (mode, model).
_PER_STEP_S: dict[tuple[str, str], float] = {
    ("t2i", "Base"): 2.4,
    ("t2i", "Turbo"): 1.6,
    ("controlnet", "Turbo"): 2.0,
    ("upscale", "Turbo"): 1.6,
}


def duration_for(
    mode: str,
    params: dict[str, Any],
    multiplier: float = 1.0,
) -> int:
    """Estimate ZeroGPU duration for a request. Pure function; clamped to [60, 180].

    The estimate is
    ``(base + per_step * steps + cold_buffer) * size_factor * multiplier``,
    truncated to an int and then clamped.

    Args:
        mode: Request mode used to look up the base and per-step costs.
            A mode missing from ``_BASE_DURATION_S`` gets a base of 30.
        params: Request parameters. Keys read: ``model`` (default
            ``"Turbo"``), ``steps`` or else ``refine_steps`` (default 8),
            ``width`` and ``height`` (default 1024 each). A missing key and
            an explicit ``None`` value are treated the same way.
        multiplier: Scale factor applied to the whole estimate.

    Returns:
        The estimated duration in seconds, never below 60 or above 180.

    Raises:
        TypeError, ValueError: If a provided step or dimension value cannot
            be converted with ``int()``.
    """
    model = params.get("model", "Turbo")
    # Falsy step values (missing, None, 0) fall through to the next source.
    steps = int(params.get("steps") or params.get("refine_steps") or 8)

    # Only None selects the default here, so an explicit ``None`` no longer
    # crashes in ``int(None)`` while every other value is converted as before.
    raw_width = params.get("width")
    raw_height = params.get("height")
    width = 1024 if raw_width is None else int(raw_width)
    height = 1024 if raw_height is None else int(raw_height)

    base = _BASE_DURATION_S.get(mode, 30)
    # An unknown (mode, model) pair falls back to that mode's Turbo rate,
    # and finally to 1.6 when the mode itself is unknown.
    per_step = _PER_STEP_S.get((mode, model), _PER_STEP_S.get((mode, "Turbo"), 1.6))
    # Pixel count relative to a 1024x1024 image.
    size_factor = (width * height) / (1024 * 1024)
    cold_buffer = 15  # CPU→GPU copy on first call after a quiet period

    est = (base + per_step * steps + cold_buffer) * size_factor * multiplier
    return max(60, min(int(est), 180))
tests/test_backend.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import backend
2
+
3
+
4
def test_duration_t2i_turbo_is_short():
    """An 8-step 1024x1024 Turbo t2i estimate falls between 60 and 90."""
    request = {"model": "Turbo", "steps": 8, "width": 1024, "height": 1024}
    estimate = backend.duration_for(mode="t2i", params=request)
    assert 60 <= estimate <= 90
7
+
8
+
9
def test_duration_t2i_base_is_longer():
    """A 25-step Base t2i estimate exceeds the floor and the Turbo estimate."""
    settings = dict(steps=25, width=1024, height=1024)
    base_est = backend.duration_for(mode="t2i", params=dict(model="Base", **settings))
    turbo_est = backend.duration_for(mode="t2i", params=dict(model="Turbo", **settings))
    # Original check: the Base estimate sits above the 60-second floor.
    assert base_est > 60
    # The comparison the test name promises: same settings, Base vs Turbo.
    assert base_est > turbo_est
12
+
13
+
14
def test_duration_clamps_at_180():
    """A 200-step 2048x2048 Base request is capped at exactly 180."""
    request = {"model": "Base", "steps": 200, "width": 2048, "height": 2048}
    estimate = backend.duration_for(mode="t2i", params=request)
    assert estimate == 180
17
+
18
+
19
def test_duration_clamps_at_60():
    """A 1-step 256x256 Turbo request is raised to exactly 60."""
    request = {"model": "Turbo", "steps": 1, "width": 256, "height": 256}
    estimate = backend.duration_for(mode="t2i", params=request)
    assert estimate == 60
22
+
23
+
24
def test_duration_multiplier_scales_up():
    """Passing multiplier=2.0 yields a larger estimate than the default."""
    request = {"model": "Turbo", "steps": 8, "width": 1024, "height": 1024}
    default_est = backend.duration_for(mode="t2i", params=dict(request))
    doubled_est = backend.duration_for(
        mode="t2i",
        params=dict(request),
        multiplier=2.0,
    )
    assert doubled_est > default_est
29
+
30
+
31
def test_duration_upscale_has_realesrgan_overhead():
    """A 5-refine-step upscale estimate exceeds an 8-step Turbo t2i estimate."""
    t2i_request = {"model": "Turbo", "steps": 8, "width": 1024, "height": 1024}
    upscale_request = {"refine_steps": 5, "width": 1024, "height": 1024}
    t2i_est = backend.duration_for(mode="t2i", params=t2i_request)
    upscale_est = backend.duration_for(mode="upscale", params=upscale_request)
    assert upscale_est > t2i_est