File size: 4,494 Bytes
9a5065c
 
 
 
 
8894ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5065c
 
 
8894ed9
 
 
 
 
 
 
3b83775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5065c
 
 
 
 
 
 
 
 
 
 
 
3b83775
 
 
 
 
 
 
9a5065c
 
3b83775
9a5065c
 
 
 
 
 
 
 
 
 
3b83775
 
 
 
 
 
 
76862de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from unittest.mock import MagicMock

import pytest
from PIL import Image

import backend


def test_duration_t2i_turbo_is_short():
    """Turbo t2i at 8 steps and 1024x1024 gets a duration inside the [60, 90] window."""
    turbo_params = {"model": "Turbo", "steps": 8, "width": 1024, "height": 1024}
    estimate = backend.duration_for(mode="t2i", params=turbo_params)
    assert estimate >= 60
    assert estimate <= 90


def test_duration_t2i_base_is_longer():
    """Base t2i (25 steps) gets a longer duration than Turbo t2i (8 steps) at the same size.

    The original assertion (``> 60``) only proved the value sat above the lower clamp
    (see ``test_duration_clamps_at_60``). It did not check the "longer" claim in the name,
    so this test compares directly against the Turbo estimate as well.
    """
    size = {"width": 1024, "height": 1024}
    turbo = backend.duration_for(mode="t2i", params=dict(model="Turbo", steps=8, **size))
    base = backend.duration_for(mode="t2i", params=dict(model="Base", steps=25, **size))
    assert base > 60
    assert base > turbo


def test_duration_clamps_at_180():
    """An oversized Base job (200 steps, 2048x2048) is capped at exactly 180."""
    oversized = {"model": "Base", "steps": 200, "width": 2048, "height": 2048}
    assert backend.duration_for(mode="t2i", params=oversized) == 180


def test_duration_clamps_at_60():
    """A tiny Turbo job (1 step, 256x256) is floored at exactly 60."""
    tiny = {"model": "Turbo", "steps": 1, "width": 256, "height": 256}
    assert backend.duration_for(mode="t2i", params=tiny) == 60


def test_duration_multiplier_scales_up():
    """Passing ``multiplier=2.0`` to duration_for yields a larger value than the default."""

    def turbo_params():
        # Fresh dict per call, so neither call can observe the other's params object.
        return {"model": "Turbo", "steps": 8, "width": 1024, "height": 1024}

    default_estimate = backend.duration_for(mode="t2i", params=turbo_params())
    scaled_estimate = backend.duration_for(mode="t2i", params=turbo_params(), multiplier=2.0)
    assert scaled_estimate > default_estimate


def test_duration_upscale_has_realesrgan_overhead():
    """Upscale (5 refine steps, 1024x1024) is estimated longer than a Turbo 8-step t2i."""
    t2i_estimate = backend.duration_for(
        mode="t2i",
        params={"model": "Turbo", "steps": 8, "width": 1024, "height": 1024},
    )
    upscale_estimate = backend.duration_for(
        mode="upscale",
        params={"refine_steps": 5, "width": 1024, "height": 1024},
    )
    assert t2i_estimate < upscale_estimate


@pytest.fixture
def fake_backend(monkeypatch):
    """Return a ZImageStudioBackend whose constructor doesn't actually build a pipeline.

    ``backend._build_pipeline`` is patched to hand back a MagicMock. That mock is then
    configured to return a 32x32 RGB PIL image when called, and to expose MagicMock
    ``dit`` and ``model_pool`` attributes.
    """

    def _stub_build_pipeline(*_args, **_kwargs):
        return MagicMock()

    monkeypatch.setattr(backend, "_build_pipeline", _stub_build_pipeline)
    studio = backend.ZImageStudioBackend()

    mock_pipeline = studio.pipeline
    mock_pipeline.return_value = Image.new("RGB", (32, 32))
    mock_pipeline.dit = MagicMock()
    mock_pipeline.model_pool = MagicMock()
    return studio


def test_backend_generate_routes_t2i(fake_backend):
    """mode="t2i" returns a PIL image plus metadata echoing the mode and model."""
    t2i_request = {
        "prompt": "cat",
        "negative_prompt": "",
        "model": "Turbo",
        "steps": 8,
        "cfg": 1.0,
        "width": 1024,
        "height": 1024,
        "seed": 42,
        "lora_path": None,
        "lora_strength": 0.0,
    }
    image, metadata = fake_backend.generate(mode="t2i", params=t2i_request)

    assert isinstance(image, Image.Image)
    assert metadata["mode"] == "t2i"
    assert metadata["model"] == "Turbo"


def test_backend_generate_routes_controlnet(fake_backend, monkeypatch):
    """mode="controlnet" succeeds with a stubbed preprocessor and reports mode "controlnet"."""

    class _PassThroughPreprocessors:
        # Replaces backend.modes.preprocessors; run() hands its second argument back unchanged.
        @staticmethod
        def run(m, i):
            return i

    monkeypatch.setattr(backend.modes, "preprocessors", _PassThroughPreprocessors)

    controlnet_request = {
        "prompt": "cat",
        "input_image": Image.new("RGB", (64, 64)),
        "preprocessor": "Canny",
        "controlnet_scale": 1.0,
        "steps": 9,
        "seed": 0,
        "lora_path": None,
        "lora_strength": 0.0,
    }
    _unused_image, metadata = fake_backend.generate(mode="controlnet", params=controlnet_request)
    assert metadata["mode"] == "controlnet"


def test_backend_generate_unknown_mode_raises(fake_backend):
    """generate() rejects an unrecognised mode string with ValueError."""
    unsupported_mode = "dance"
    with pytest.raises(ValueError):
        fake_backend.generate(mode=unsupported_mode, params={})


def test_generate_with_retry_retries_on_gpu_aborted(fake_backend, monkeypatch):
    """A gradio ``Error("GPU task aborted")`` on the first attempt triggers exactly one retry.

    ``fake_backend.generate`` is swapped (via monkeypatch) for a wrapper that raises on its
    first call and delegates to the real mock-pipeline ``generate`` on later calls.
    """
    # Import once up front rather than inside the closure, so a missing gradio install
    # fails here instead of surfacing from inside generate_with_retry.
    from gradio.exceptions import Error as GradioError

    attempts = 0
    original_generate = fake_backend.generate

    def flaky(mode, params):
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise GradioError("GPU task aborted")
        # Forward by keyword, matching how the other tests call generate(), so the test
        # doesn't depend on generate's positional parameter order.
        return original_generate(mode=mode, params=params)

    monkeypatch.setattr(fake_backend, "generate", flaky)

    _img, meta = backend.generate_with_retry(
        fake_backend,
        mode="t2i",
        params=dict(
            prompt="x",
            negative_prompt="",
            model="Turbo",
            steps=8,
            cfg=1.0,
            width=1024,
            height=1024,
            seed=0,
            lora_path=None,
            lora_strength=0.0,
        ),
    )
    assert attempts == 2  # one fail + one retry
    assert meta["mode"] == "t2i"


def test_generate_with_retry_does_not_retry_other_errors(fake_backend):
    """A non-GPU error propagates unchanged and ``generate`` is attempted exactly once.

    Checking only ``pytest.raises(ValueError)`` cannot detect a retry, because a retried
    call would raise the same ValueError again. The call count is therefore asserted
    explicitly.
    """
    calls = []

    def failing_generate(*args, **kwargs):
        calls.append((args, kwargs))
        raise ValueError("not a gpu issue")

    fake_backend.generate = failing_generate
    with pytest.raises(ValueError, match="not a gpu issue"):
        backend.generate_with_retry(fake_backend, mode="t2i", params={})
    assert len(calls) == 1


def test_duration_honors_retry_multiplier_in_params():
    """A ``__retry_multiplier__`` entry inside params increases the returned duration."""
    shared = {"model": "Turbo", "steps": 8, "width": 1024, "height": 1024}
    plain_estimate = backend.duration_for(mode="t2i", params=dict(shared))
    retry_estimate = backend.duration_for(
        mode="t2i",
        params={**shared, "__retry_multiplier__": 2.0},
    )
    assert plain_estimate < retry_estimate