Spaces:
Running on Zero
Running on Zero
File size: 5,717 Bytes
8894ed9 9a5065c 8894ed9 3b83775 8894ed9 9a5065c 8894ed9 9a5065c 8894ed9 9a5065c 8894ed9 76862de 8894ed9 76862de 8894ed9 3b83775 9a5065c 3b83775 0cf8ffc 3b83775 0cf8ffc 3b83775 9a5065c 3b83775 0cf8ffc 3b83775 0cf8ffc 3b83775 9a5065c 3b83775 9a5065c 3b83775 76862de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | """ZImageStudioBackend — wraps the DiffSynth pipeline; applies @spaces.GPU on HF Spaces."""
from __future__ import annotations
import os
from typing import Any
# Spaces import is optional — running locally we don't have it.
try:
import spaces # type: ignore
except ImportError:
spaces = None # type: ignore[assignment]
import modes
# Fixed per-mode overhead in seconds, added before the per-step cost.
_BASE_DURATION_S: dict[str, int] = {
    "t2i": 20,  # fixed setup + decode
    "controlnet": 30,  # + preprocessor + control patch
    "upscale": 50,  # + realesrgan pixel-space step
}
# Seconds per denoising step, keyed by (mode, model). A missing pair falls back
# to the mode's "Turbo" entry, and then to 1.6 s.
_PER_STEP_S: dict[tuple[str, str], float] = {
    ("t2i", "Base"): 2.4,
    ("t2i", "Turbo"): 1.6,
    ("controlnet", "Turbo"): 2.0,
    ("upscale", "Turbo"): 1.6,
}
# Every estimate is clamped into this window (seconds).
_MIN_DURATION_S = 60
_MAX_DURATION_S = 180


def duration_for(
    mode: str,
    params: dict[str, Any],
    multiplier: float = 1.0,
) -> int:
    """Estimate ZeroGPU duration for a request. Pure function; clamped to [60, 180].

    The estimate is
    ``(base[mode] + per_step[mode, model] * steps + 15) * (w * h / 1024**2) * multiplier``.
    It is truncated to an int and then clamped. Missing, ``None`` or other
    falsy values in ``params`` fall back to the defaults below instead of
    raising.

    Args:
        mode: Request mode. Unknown modes use a 30 s base.
        params: Request parameters. Reads:
            - ``"model"`` (default ``"Turbo"``);
            - ``"steps"``, else ``"refine_steps"`` (default 8);
            - ``"width"`` and ``"height"`` (default 1024 each);
            - the internal retry key ``"__retry_multiplier__"``.
        multiplier: Scale factor, used only when ``params`` carries no
            ``"__retry_multiplier__"`` value.

    Returns:
        Estimated seconds, in ``[60, 180]``.
    """
    model = params.get("model") or "Turbo"
    steps = int(params.get("steps") or params.get("refine_steps") or 8)
    # Same fallback style as `steps`. A None width/height previously hit int(None).
    width = int(params.get("width") or 1024)
    height = int(params.get("height") or 1024)
    retry_multiplier = params.get("__retry_multiplier__")
    eff_multiplier = float(multiplier if retry_multiplier is None else retry_multiplier)

    base = _BASE_DURATION_S.get(mode, 30)
    per_step = _PER_STEP_S.get((mode, model), _PER_STEP_S.get((mode, "Turbo"), 1.6))
    size_factor = (width * height) / (1024 * 1024)
    cold_buffer = 15  # CPU→GPU copy on first call after a quiet period
    # NOTE: size_factor scales the whole sum, including the fixed base and cold buffer.
    est = (base + per_step * steps + cold_buffer) * size_factor * eff_multiplier
    return max(_MIN_DURATION_S, min(int(est), _MAX_DURATION_S))
def _identity(fn: Any) -> Any:
    """Pass-through decorator: hand back ``fn`` itself, unwrapped.

    Used in place of ``spaces.GPU`` when not running on ZeroGPU.
    """
    return fn
# True only when the ZeroGPU flag is set to a truthy value. Previously any
# non-empty string counted, so "0" or "false" would still enable spaces.GPU.
_ON_SPACES = os.environ.get("SPACES_ZERO_GPU", "").strip().lower() in ("1", "t", "true", "yes")


def _gpu_duration(_backend: Any, mode: str, params: dict[str, Any]) -> int:
    """ZeroGPU ``duration`` callback for the backend's GPU-decorated method.

    It receives the decorated method's arguments ``(self, mode, params)``.
    Binding them by name works whether the caller passed them positionally or
    as keywords, unlike the previous ``a[1:3]`` slice.
    """
    return duration_for(mode, params)


# Decorator for GPU-bound work: spaces.GPU with a per-request duration on
# ZeroGPU, otherwise a no-op.
_GPU = (
    spaces.GPU(duration=_gpu_duration)
    if (spaces is not None and _ON_SPACES)
    else _identity
)
def _build_pipeline() -> Any:
    """Build one ZImagePipeline that holds BOTH the Base and Turbo transformers.

    ``ZImagePipeline.from_pretrained`` in DiffSynth creates its ``ModelPool`` as
    a local and discards it once ``pipe.dit`` and the other models are attached.
    That leaves nothing for a later transformer swap to choose from. This
    function therefore repeats that initialization by hand and stores the pool
    on ``pipe._zis_pool``. :func:`modes._swap_transformer` can then point
    ``pipe.dit`` at either ``z_image_dit`` entry (Base loads first and Turbo
    second, per MODEL_CONFIGS).
    """
    import torch
    from diffsynth.pipelines.z_image import ZImagePipeline
    from transformers import AutoTokenizer
    import models

    device = models.auto_device()
    bf16 = torch.bfloat16

    # Off-CPU: weights are kept on the CPU in bf16 and prepared/computed on
    # `device`. On CPU, no VRAM-management settings are passed.
    vram_cfg: dict[str, Any] = (
        {
            "offload_dtype": bf16,
            "offload_device": "cpu",
            "onload_dtype": bf16,
            "onload_device": "cpu",
            "preparing_dtype": bf16,
            "preparing_device": device,
            "computation_dtype": bf16,
            "computation_device": device,
        }
        if device != "cpu"
        else {}
    )

    pipe = ZImagePipeline(device=device, torch_dtype=bf16)

    # One shared pool for every MODEL_CONFIGS entry: both transformers, the
    # text encoder, the VAE halves and the controlnet.
    all_configs = models.build_diffsynth_configs(vram_cfg=vram_cfg)
    pool = pipe.download_and_load_models(all_configs, vram_limit=models.vram_limit_for(device))
    pipe._zis_pool = pool

    # (pipeline attribute, pool model name), fetched in this order.
    attachments = (
        ("text_encoder", "z_image_text_encoder"),
        ("dit", "z_image_dit"),  # first z_image_dit match = Base (load order)
        ("vae_encoder", "flux_vae_encoder"),
        ("vae_decoder", "flux_vae_decoder"),
        ("controlnet", "z_image_controlnet"),
        # Optional encoders that ZImagePipeline references but this app does
        # not preload (Omni / image2lora). fetch_model yields None when they
        # are absent, which is the documented non-error path.
        ("image_encoder", "siglip_vision_model_428m"),
        ("siglip2_image_encoder", "siglip2_image_encoder"),
        ("dinov3_image_encoder", "dinov3_image_encoder"),
        ("image2lora_style", "z_image_image2lora_style"),
    )
    for attr_name, model_name in attachments:
        setattr(pipe, attr_name, pool.fetch_model(model_name))

    # Tokenizer: the Qwen3-4B tokenizer directory shipped with Z-Image.
    tokenizer_cfg = models.build_diffsynth_configs((models.TOKENIZER_CONFIG,), vram_cfg=None)[0]
    tokenizer_cfg.download_if_necessary()
    pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_cfg.path)

    pipe.vram_management_enabled = pipe.check_vram_management_state()
    return pipe
# Mode name -> handler(pipeline, params). Key order is the order shown in the
# "unknown mode" error message.
_DISPATCH = dict(
    t2i=modes.call_t2i,
    controlnet=modes.call_controlnet,
    upscale=modes.call_upscale,
)
class ZImageStudioBackend:
    """One-process backend wrapping the DiffSynth ZImagePipeline.

    The pipeline is built once in ``__init__``. ``generate`` routes each
    request to the handler registered for its mode in ``_DISPATCH``.
    """

    def __init__(self) -> None:
        self.pipeline = _build_pipeline()

    def generate(self, mode: str, params: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
        """Run the ``mode`` handler on the shared pipeline and return its result.

        The mode is checked here, before the ``_GPU``-decorated call. A bad
        mode therefore fails immediately, without going through the ZeroGPU
        wrapper (which would schedule GPU time first).

        Raises:
            ValueError: If ``mode`` is not a key of ``_DISPATCH``.
        """
        if mode not in _DISPATCH:
            raise ValueError(f"unknown mode: {mode!r}; expected one of {list(_DISPATCH)}")
        # Positional call keeps the (self, mode, params) layout the GPU duration callback expects.
        return self._generate_on_gpu(mode, params)

    @_GPU
    def _generate_on_gpu(self, mode: str, params: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
        """GPU-scoped body: run the already-validated mode's handler."""
        return _DISPATCH[mode](self.pipeline, params)
def _is_gpu_abort(exc: BaseException) -> bool:
    """Return True if ``exc``'s message looks like a ZeroGPU "GPU task aborted" error."""
    msg = str(exc).lower()
    # "gpu task aborted" is a special case of this test, so one check is enough.
    return "gpu" in msg and "aborted" in msg


def generate_with_retry(
    backend_instance: ZImageStudioBackend,
    mode: str,
    params: dict[str, Any],
    *,
    retry_multiplier: float = 2.0,
) -> tuple[Any, dict[str, Any]]:
    """Call backend_instance.generate; on ZeroGPU timeout, retry once with a larger duration budget.

    The retry runs on a copy of ``params`` with ``"__retry_multiplier__"`` set.
    ``duration_for`` reads that key to scale its estimate, which is still capped
    at 180 s. The caller's ``params`` dict is never mutated.

    Args:
        backend_instance: Object exposing ``generate(mode, params)``.
        mode: Request mode, forwarded unchanged.
        params: Request parameters.
        retry_multiplier: Duration multiplier for the single retry. The default
            2.0 matches the previous hard-coded value.

    Returns:
        Whatever ``generate`` returns.

    Raises:
        Any non-abort exception from the first call, or any exception from the
        retry. The first abort stays attached as ``__context__``.
    """
    try:
        return backend_instance.generate(mode, params)
    except Exception as exc:
        if not _is_gpu_abort(exc):
            raise
        retry_params = dict(params)
        retry_params["__retry_multiplier__"] = retry_multiplier
        return backend_instance.generate(mode, retry_params)
|