z-image-studio / backend.py
techfreakworm's picture
fix: pool-stashed transformer swap + MPS-safe vram + corrected model-zoo anchor
0cf8ffc unverified
"""ZImageStudioBackend — wraps the DiffSynth pipeline; applies @spaces.GPU on HF Spaces."""
from __future__ import annotations
import os
from typing import Any
# Spaces import is optional — running locally we don't have it.
try:
import spaces # type: ignore
except ImportError:
spaces = None # type: ignore[assignment]
import modes
_BASE_DURATION_S: dict[str, int] = {
"t2i": 20, # fixed setup + decode
"controlnet": 30, # + preprocessor + control patch
"upscale": 50, # + realesrgan pixel-space step
}
_PER_STEP_S: dict[tuple[str, str], float] = {
("t2i", "Base"): 2.4,
("t2i", "Turbo"): 1.6,
("controlnet", "Turbo"): 2.0,
("upscale", "Turbo"): 1.6,
}
def duration_for(
mode: str,
params: dict[str, Any],
multiplier: float = 1.0,
) -> int:
"""Estimate ZeroGPU duration for a request. Pure function; clamped to [60, 180]."""
model = params.get("model", "Turbo")
steps = int(params.get("steps") or params.get("refine_steps") or 8)
width = int(params.get("width", 1024))
height = int(params.get("height", 1024))
eff_multiplier = float(params.get("__retry_multiplier__", multiplier))
base = _BASE_DURATION_S.get(mode, 30)
per_step = _PER_STEP_S.get((mode, model), _PER_STEP_S.get((mode, "Turbo"), 1.6))
size_factor = (width * height) / (1024 * 1024)
cold_buffer = 15 # CPU→GPU copy on first call after a quiet period
est = (base + per_step * steps + cold_buffer) * size_factor * eff_multiplier
return max(60, min(int(est), 180))
def _identity(fn):
return fn
# True when SPACES_ZERO_GPU is set to any non-empty string (even "0").
_ON_SPACES = bool(os.environ.get("SPACES_ZERO_GPU"))


def _gpu_duration(*args, **kwargs):
    """Duration callback for ``spaces.GPU``.

    ZeroGPU calls this with the same arguments as the decorated method,
    i.e. ``(self, mode, params)``; ``self`` is dropped before delegating to
    :func:`duration_for`. Keyword arguments are forwarded unchanged.
    """
    return duration_for(*args[1:3], **kwargs)


if spaces is not None and _ON_SPACES:
    _GPU = spaces.GPU(duration=_gpu_duration)
else:
    # Running locally (or without the spaces package): no-op decorator.
    _GPU = _identity
def _build_pipeline() -> Any:
    """Construct a ZImagePipeline carrying BOTH Base and Turbo transformers.

    DiffSynth's ``ZImagePipeline.from_pretrained`` builds a fresh ``ModelPool``
    locally and throws it away after attaching ``pipe.dit`` etc. — so a later
    transformer swap has nothing to switch between. We replicate the same
    initialization manually and keep the pool on ``pipe._zis_pool`` so
    :func:`modes._swap_transformer` can flip ``pipe.dit`` between the two
    ``z_image_dit`` entries (Base loaded first, Turbo second per MODEL_CONFIGS).

    Returns:
        The assembled ``ZImagePipeline`` with Base attached as ``pipe.dit``,
        the pool stashed on ``pipe._zis_pool``, and the tokenizer loaded.
    """
    # Third-party and project imports are done at call time rather than at
    # module import — presumably to keep importing this module cheap.
    import torch
    from diffsynth.pipelines.z_image import ZImagePipeline
    from transformers import AutoTokenizer
    import models
    device = models.auto_device()
    # VRAM-management settings forwarded into every model config. Left empty
    # when the device is "cpu". Otherwise every stage uses bfloat16; offload
    # and onload both target "cpu", and only the preparing/computation stages
    # use ``device``.
    # NOTE(review): onload_device="cpu" (rather than ``device``) looks
    # deliberate — the commit message mentions "MPS-safe vram"; confirm
    # before changing.
    vram_cfg: dict[str, Any] = {}
    if device != "cpu":
        vram_cfg = dict(
            offload_dtype=torch.bfloat16,
            offload_device="cpu",
            onload_dtype=torch.bfloat16,
            onload_device="cpu",
            preparing_dtype=torch.bfloat16,
            preparing_device=device,
            computation_dtype=torch.bfloat16,
            computation_device=device,
        )
    pipe = ZImagePipeline(device=device, torch_dtype=torch.bfloat16)
    # Load every safetensors listed in MODEL_CONFIGS — both transformers + shared
    # text encoder + VAE + controlnet — into one pool.
    pool = pipe.download_and_load_models(
        models.build_diffsynth_configs(vram_cfg=vram_cfg),
        vram_limit=models.vram_limit_for(device),
    )
    # Keep the pool reachable from the pipeline so modes._swap_transformer can
    # later re-point ``pipe.dit`` at the other ``z_image_dit`` entry.
    pipe._zis_pool = pool
    pipe.text_encoder = pool.fetch_model("z_image_text_encoder")
    pipe.dit = pool.fetch_model("z_image_dit") # first match = Base per load order
    pipe.vae_encoder = pool.fetch_model("flux_vae_encoder")
    pipe.vae_decoder = pool.fetch_model("flux_vae_decoder")
    pipe.controlnet = pool.fetch_model("z_image_controlnet")
    # Optional image encoders that DiffSynth's ZImagePipeline references but
    # aren't in our preload (Omni / image2lora). fetch_model returns None when
    # absent — that's the documented "not an error" path.
    pipe.image_encoder = pool.fetch_model("siglip_vision_model_428m")
    pipe.siglip2_image_encoder = pool.fetch_model("siglip2_image_encoder")
    pipe.dinov3_image_encoder = pool.fetch_model("dinov3_image_encoder")
    pipe.image2lora_style = pool.fetch_model("z_image_image2lora_style")
    # Tokenizer (Qwen3-4B tokenizer dir under Z-Image). Its config is built
    # without VRAM settings (vram_cfg=None); after downloading, only its local
    # path is used, handed to transformers' AutoTokenizer.
    tok_cfg = models.build_diffsynth_configs((models.TOKENIZER_CONFIG,), vram_cfg=None)[0]
    tok_cfg.download_if_necessary()
    pipe.tokenizer = AutoTokenizer.from_pretrained(tok_cfg.path)
    # Record DiffSynth's view of whether VRAM management is active now that
    # all models are attached (presumably consulted during inference).
    pipe.vram_management_enabled = pipe.check_vram_management_state()
    return pipe
# Mode name -> handler(pipeline, params). Insertion order matters: it is echoed
# in the "unknown mode" ValueError raised by ZImageStudioBackend.generate.
_DISPATCH = dict(
    t2i=modes.call_t2i,
    controlnet=modes.call_controlnet,
    upscale=modes.call_upscale,
)
class ZImageStudioBackend:
    """Single-process backend that owns one DiffSynth ``ZImagePipeline``.

    The pipeline is built eagerly on construction. ``generate`` is wrapped by
    ``_GPU``, which is ``spaces.GPU`` on ZeroGPU Spaces and a no-op elsewhere.
    """

    def __init__(self) -> None:
        self.pipeline = _build_pipeline()

    @_GPU
    def generate(self, mode: str, params: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
        """Run the handler registered for *mode* against the shared pipeline.

        Raises:
            ValueError: if *mode* has no entry in ``_DISPATCH``.
        """
        if mode not in _DISPATCH:
            raise ValueError(f"unknown mode: {mode!r}; expected one of {list(_DISPATCH)}")
        run_mode = _DISPATCH[mode]
        return run_mode(self.pipeline, params)
def generate_with_retry(
    backend_instance: ZImageStudioBackend,
    mode: str,
    params: dict[str, Any],
) -> tuple[Any, dict[str, Any]]:
    """Call ``backend_instance.generate``; retry once on a ZeroGPU abort.

    An exception whose lower-cased message mentions both "gpu" and "aborted"
    triggers exactly one retry. The retry uses a copy of *params* carrying
    ``__retry_multiplier__ = 2.0``, which doubles the duration estimate before
    clamping. The caller's *params* are never mutated. Any other exception, or
    a failure of the retry itself, propagates.
    """
    try:
        return backend_instance.generate(mode, params)
    except Exception as exc:
        text = str(exc).lower()
        # "gpu task aborted" is a special case of this test, so one check covers both.
        if not ("gpu" in text and "aborted" in text):
            raise
        # Retry inside the handler so a second failure chains to the first.
        boosted = {**params, "__retry_multiplier__": 2.0}
        return backend_instance.generate(mode, boosted)