File size: 5,717 Bytes
8894ed9
9a5065c
8894ed9
 
 
 
 
 
 
 
 
 
 
3b83775
 
8894ed9
9a5065c
 
 
8894ed9
 
9a5065c
8894ed9
 
9a5065c
8894ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
76862de
 
8894ed9
 
 
 
 
76862de
8894ed9
3b83775
 
 
 
 
 
 
9a5065c
 
 
 
 
3b83775
 
 
0cf8ffc
 
 
 
 
 
 
 
 
3b83775
 
0cf8ffc
3b83775
 
 
 
 
 
 
9a5065c
 
 
 
 
 
 
 
3b83775
 
0cf8ffc
 
 
 
 
 
3b83775
 
0cf8ffc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b83775
 
 
 
9a5065c
3b83775
9a5065c
3b83775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76862de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""ZImageStudioBackend — wraps the DiffSynth pipeline; applies @spaces.GPU on HF Spaces."""

from __future__ import annotations

import os
from typing import Any

# Spaces import is optional — running locally we don't have it.
try:
    import spaces  # type: ignore
except ImportError:
    spaces = None  # type: ignore[assignment]

import modes

# Fixed per-request cost in seconds, keyed by mode (independent of step count).
_BASE_DURATION_S: dict[str, int] = dict(
    t2i=20,  # fixed setup + decode
    controlnet=30,  # t2i cost + preprocessor + control patch
    upscale=50,  # additionally pays for the realesrgan pixel-space step
)
# Cost per step in seconds, keyed by (mode, model). Only t2i has a "Base" row;
# duration_for falls back to the mode's "Turbo" row for missing pairs.
_PER_STEP_S: dict[tuple[str, str], float] = {
    (mode_name, model_name): seconds
    for mode_name, model_name, seconds in (
        ("t2i", "Base", 2.4),
        ("t2i", "Turbo", 1.6),
        ("controlnet", "Turbo", 2.0),
        ("upscale", "Turbo", 1.6),
    )
}


def duration_for(
    mode: str,
    params: dict[str, Any],
    multiplier: float = 1.0,
) -> int:
    """Estimate the ZeroGPU duration budget, in seconds, for one request.

    The estimate is ``base + per_step * steps + cold_buffer``, scaled by the
    requested pixel count relative to 1024x1024 and by a retry multiplier,
    then clamped to ``[60, 180]``. Pure function: it reads only its arguments
    and the module-level timing tables.

    Args:
        mode: Key into ``_BASE_DURATION_S``; unknown modes use a base of 30 s.
        params: Request parameters. Keys read: ``model`` (default ``"Turbo"``),
            ``steps`` or ``refine_steps`` (default 8), ``width`` / ``height``
            (default 1024 each), and ``__retry_multiplier__``, which overrides
            *multiplier* when present.
        multiplier: Scale factor applied to the whole estimate.

    Returns:
        The estimate truncated to an int and clamped to the range 60..180.
    """
    model = params.get("model", "Turbo")
    # ``or`` fallbacks (rather than ``.get`` defaults) so that a key that is
    # present but None/0 still resolves to the default instead of crashing
    # in int(None).
    steps = int(params.get("steps") or params.get("refine_steps") or 8)
    width = int(params.get("width") or 1024)
    height = int(params.get("height") or 1024)

    # A multiplier carried inside params (set by the retry path) wins over
    # the keyword argument.
    eff_multiplier = float(params.get("__retry_multiplier__", multiplier))

    base = _BASE_DURATION_S.get(mode, 30)
    # Unknown (mode, model) pairs fall back to the mode's Turbo rate, then 1.6.
    per_step = _PER_STEP_S.get((mode, model), _PER_STEP_S.get((mode, "Turbo"), 1.6))
    size_factor = (width * height) / (1024 * 1024)
    cold_buffer = 15  # CPU→GPU copy on first call after a quiet period

    est = (base + per_step * steps + cold_buffer) * size_factor * eff_multiplier
    return max(60, min(int(est), 180))


def _identity(fn):
    """Return *fn* unchanged — a no-op decorator."""
    return fn


# Any non-empty SPACES_ZERO_GPU value marks the process as running on ZeroGPU.
_ON_SPACES = bool(os.environ.get("SPACES_ZERO_GPU"))

if spaces is not None and _ON_SPACES:
    # The duration callable receives the decorated call's arguments. For
    # ``generate(self, mode, params)`` a[0] is ``self``, so a[1:3] forwards
    # (mode, params) to duration_for; keyword arguments pass straight through.
    _GPU = spaces.GPU(duration=lambda *a, **kw: duration_for(*a[1:3], **kw))
else:
    # No ``spaces`` package, or not on ZeroGPU: decorate with a no-op.
    _GPU = _identity


def _build_pipeline() -> Any:
    """Build a ZImagePipeline that holds both the Base and Turbo transformers.

    ``ZImagePipeline.from_pretrained`` in DiffSynth creates its ``ModelPool``
    locally and discards it once ``pipe.dit`` and friends are attached, which
    leaves nothing to swap transformers from later. This function performs the
    same initialization by hand and stores the pool on ``pipe._zis_pool`` so
    that :func:`modes._swap_transformer` can point ``pipe.dit`` at either of
    the two ``z_image_dit`` entries (Base is loaded first and Turbo second,
    per MODEL_CONFIGS).

    Returns:
        The assembled pipeline with models, tokenizer and the
        ``vram_management_enabled`` flag set.
    """
    import torch
    from diffsynth.pipelines.z_image import ZImagePipeline
    from transformers import AutoTokenizer

    import models

    device = models.auto_device()
    bf16 = torch.bfloat16

    # VRAM-management settings are only supplied off-CPU: weights are kept on
    # the CPU and prepared/computed on the target device, all in bfloat16.
    if device == "cpu":
        vram_cfg: dict[str, Any] = {}
    else:
        vram_cfg = {
            "offload_dtype": bf16,
            "offload_device": "cpu",
            "onload_dtype": bf16,
            "onload_device": "cpu",
            "preparing_dtype": bf16,
            "preparing_device": device,
            "computation_dtype": bf16,
            "computation_device": device,
        }

    pipe = ZImagePipeline(device=device, torch_dtype=bf16)

    # One pool for every safetensors entry in MODEL_CONFIGS: both
    # transformers, the shared text encoder, the VAE and the controlnet.
    model_configs = models.build_diffsynth_configs(vram_cfg=vram_cfg)
    vram_limit = models.vram_limit_for(device)
    pool = pipe.download_and_load_models(model_configs, vram_limit=vram_limit)
    pipe._zis_pool = pool

    # (pipeline attribute, pool model name), attached in this order.
    attachments = (
        ("text_encoder", "z_image_text_encoder"),
        ("dit", "z_image_dit"),  # first match = Base per load order
        ("vae_encoder", "flux_vae_encoder"),
        ("vae_decoder", "flux_vae_decoder"),
        ("controlnet", "z_image_controlnet"),
        # Optional encoders referenced by DiffSynth's ZImagePipeline but not
        # part of our preload (Omni / image2lora). fetch_model returns None
        # when a model is absent, which is the documented non-error path.
        ("image_encoder", "siglip_vision_model_428m"),
        ("siglip2_image_encoder", "siglip2_image_encoder"),
        ("dinov3_image_encoder", "dinov3_image_encoder"),
        ("image2lora_style", "z_image_image2lora_style"),
    )
    for attr_name, model_name in attachments:
        setattr(pipe, attr_name, pool.fetch_model(model_name))

    # Tokenizer: the Qwen3-4B tokenizer directory under Z-Image.
    tokenizer_cfg = models.build_diffsynth_configs(
        (models.TOKENIZER_CONFIG,),
        vram_cfg=None,
    )[0]
    tokenizer_cfg.download_if_necessary()
    pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_cfg.path)

    pipe.vram_management_enabled = pipe.check_vram_management_state()
    return pipe


# Mode name → handler in ``modes``; each handler is called as
# ``handler(pipeline, params)``.
_DISPATCH = dict(
    t2i=modes.call_t2i,
    controlnet=modes.call_controlnet,
    upscale=modes.call_upscale,
)


class ZImageStudioBackend:
    """Single-process backend that owns one DiffSynth ZImagePipeline.

    The pipeline is built once in the constructor; ``generate`` routes each
    request to the handler registered for its mode in ``_DISPATCH``.
    """

    def __init__(self) -> None:
        # Built eagerly, so model loading happens at construction time.
        self.pipeline = _build_pipeline()

    @_GPU
    def generate(self, mode: str, params: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
        """Run one request through the handler registered for *mode*.

        Args:
            mode: A key of ``_DISPATCH``.
            params: Request parameters, passed to the handler unchanged.

        Returns:
            Whatever the selected handler returns.

        Raises:
            ValueError: If *mode* has no entry in ``_DISPATCH``.
        """
        handler = _DISPATCH.get(mode)
        if handler is not None:
            return handler(self.pipeline, params)
        raise ValueError(f"unknown mode: {mode!r}; expected one of {list(_DISPATCH)}")


def generate_with_retry(
    backend_instance: ZImageStudioBackend,
    mode: str,
    params: dict[str, Any],
) -> tuple[Any, dict[str, Any]]:
    """Call ``backend_instance.generate``, retrying once after a GPU abort.

    An exception whose lowercased message contains both ``"gpu"`` and
    ``"aborted"`` is treated as a ZeroGPU timeout. The call is then repeated
    once with a copy of *params* carrying ``__retry_multiplier__ = 2.0``
    (a 2x duration budget); the caller's dict is not modified. Any other
    exception, and any exception from the retry itself, propagates.
    """
    try:
        return backend_instance.generate(mode, params)
    except Exception as exc:
        text = str(exc).lower()
        # "gpu task aborted" contains both words, so this one check covers it.
        if "gpu" not in text or "aborted" not in text:
            raise
        boosted_params = {**params, "__retry_multiplier__": 2.0}
        return backend_instance.generate(mode, boosted_params)