File size: 7,473 Bytes
613dab3
9a5065c
613dab3
 
 
9a5065c
261639d
613dab3
 
 
 
 
 
 
 
 
 
 
 
 
9a5065c
613dab3
 
 
 
 
 
 
0cf8ffc
613dab3
 
0cf8ffc
 
 
 
 
613dab3
 
 
 
0cf8ffc
 
 
 
 
 
 
613dab3
 
9a5065c
0cf8ffc
613dab3
 
 
 
 
 
 
 
 
 
 
9a5065c
613dab3
 
 
 
 
 
 
9a5065c
 
 
 
 
 
 
613dab3
9a5065c
613dab3
9a5065c
99302bc
9a5065c
 
 
613dab3
 
9a5065c
613dab3
 
 
 
 
 
 
 
 
 
 
 
9a5065c
613dab3
9a5065c
613dab3
261639d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99302bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Device autodetect, ZImagePipeline ModelConfig registry, and (Task 4) HF cache mirror."""

from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

# Avoid importing torch at module load — keeps `import models` fast in CI.


def on_spaces() -> bool:
    """Report whether this process runs inside a Hugging Face ZeroGPU Space.

    The decision is based solely on the ``SPACES_ZERO_GPU`` environment
    variable: any non-empty value counts as "on Spaces"; an unset or empty
    variable does not.
    """
    marker = os.getenv("SPACES_ZERO_GPU")
    return True if marker else False


def auto_device() -> str:
    """Pick the compute backend to use, in order of preference.

    Returns:
        ``"cuda"`` when CUDA is available, otherwise ``"mps"`` when the MPS
        backend is available, otherwise ``"cpu"``.
    """
    # Imported lazily so that importing this module never pulls in torch.
    import torch

    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        # Only probed when CUDA is absent.
        chosen = "mps"
    else:
        chosen = "cpu"
    return chosen


def vram_limit_for(device: str, free_gb: float | None = None) -> float | None:
    """Compute the VRAM limit (in GB) handed to DiffSynth's vram_management.

    Args:
        device: Target device name. ``"cpu"`` and ``"mps"`` are special-cased;
            every other value is treated as CUDA.
        free_gb: Memory figure in GB to base the CUDA limit on. When omitted,
            it is read from ``torch.cuda.mem_get_info()``.

    Returns:
        - ``0.0`` for CPU (no offload budget).
        - ``None`` for MPS. PyTorch's MPS backend has no ``mem_get_info`` API,
          and DiffSynth's ``check_free_vram`` raises AttributeError when called
          on MPS. ``None`` short-circuits that check (``vram/layers.py:195``)
          so module swapping still works without the gate.
        - For CUDA, ``free_gb - 4.0`` (headroom for loaded models + scratch),
          but never less than ``8.0``.
    """
    if device == "mps":
        # No ``torch.mps.mem_get_info`` exists, so DiffSynth's
        # ``AutoWrappedModule.check_free_vram`` would raise AttributeError.
        # None disables that gate (vram/layers.py:195) while keeping
        # CPU<->MPS offload/onload swapping functional.
        return None
    if device == "cpu":
        return 0.0

    # Everything else is handled as CUDA.
    if free_gb is None:
        import torch

        # NOTE: element [1] of mem_get_info() is the device's *total* memory
        # (element [0] is the currently free amount), despite the name of the
        # ``free_gb`` parameter.
        reported_bytes = torch.cuda.mem_get_info()[1]
        free_gb = reported_bytes / (1024**3)

    budget = free_gb - 4.0
    return budget if budget > 8.0 else 8.0


@dataclass(frozen=True)
class ModelConfig:
    """Lightweight wrapper around DiffSynth's ModelConfig.

    Stored as plain, immutable data (``frozen=True``) so this module imports
    cheaply in CI without pulling in DiffSynth. The real
    ``diffsynth.core.ModelConfig`` instance is built on demand by
    :func:`build_diffsynth_configs`, which forwards ``model_id`` and
    ``origin_file_pattern`` only.
    """

    # Repository identifier in "<org>/<repo>" form, e.g. "Tongyi-MAI/Z-Image".
    model_id: str
    # File selector inside the repository (glob such as
    # "transformer/*.safetensors", a single file name, or a directory).
    origin_file_pattern: str
    # Free-form human-readable note; not passed on to DiffSynth.
    description: str = ""


# Registry of every weight set the pipeline loads, in load order.
MODEL_CONFIGS: tuple[ModelConfig, ...] = (
    # --- Base model -------------------------------------------------------
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="transformer/*.safetensors",
        description="Z-Image base transformer (25 steps, cfg=4)",
    ),
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="text_encoder/*.safetensors",
        description="Qwen3-4B text encoder — shared between base + turbo",
    ),
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image",
        origin_file_pattern="vae/diffusion_pytorch_model.safetensors",
        description="Flux-family VAE — shared between base + turbo",
    ),
    # --- Turbo ------------------------------------------------------------
    # Transformer only: the text encoder and VAE are taken from the Z-Image
    # entries listed above.
    ModelConfig(
        model_id="Tongyi-MAI/Z-Image-Turbo",
        origin_file_pattern="transformer/*.safetensors",
        description="Z-Image-Turbo transformer (8 steps, cfg=1)",
    ),
    # --- ControlNet Union 2.1 ----------------------------------------------
    # Eagerly preloaded per spec; can be moved to lazy loading if RAM is tight.
    ModelConfig(
        model_id="alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
        origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors",
        description="ControlNet Union 2.1 — canny/depth/pose",
    ),
)

# The tokenizer is kept outside MODEL_CONFIGS as its own standalone entry.
TOKENIZER_CONFIG = ModelConfig(
    model_id="Tongyi-MAI/Z-Image",
    origin_file_pattern="tokenizer/",
    description="Qwen3-4B tokenizer",
)


def build_diffsynth_configs(
    configs: tuple[ModelConfig, ...] = MODEL_CONFIGS,
    vram_cfg: dict[str, Any] | None = None,
) -> list[Any]:
    """Turn the lightweight registry entries into real DiffSynth ``ModelConfig`` objects.

    Meant to be called at app boot rather than at module import, because it
    imports ``diffsynth``.

    Args:
        configs: Registry entries to convert; defaults to ``MODEL_CONFIGS``.
        vram_cfg: Optional extra keyword arguments applied to every DiffSynth
            config — the disk-offload block (offload_dtype, offload_device,
            etc.) that DiffSynth's low-VRAM examples use. ``None`` or an empty
            dict adds nothing.

    Returns:
        One ``diffsynth.core.ModelConfig`` per entry, in the same order.
    """
    from diffsynth.core import ModelConfig as DSConfig

    shared_kwargs = vram_cfg or {}
    built: list[Any] = []
    for entry in configs:
        # ``description`` is intentionally not forwarded; it is for humans only.
        ds_config = DSConfig(
            model_id=entry.model_id,
            origin_file_pattern=entry.origin_file_pattern,
            **shared_kwargs,
        )
        built.append(ds_config)
    return built


def mirror_preload_hf_cache(src_root: Path | str, dst_root: Path | str) -> None:
    """Mirror a read-only HF cache tree (preload_from_hub) into a writable tree.

    Only the ``hub/`` subtree of ``src_root`` is mirrored, to the same relative
    location under ``dst_root``:

    - ``blobs/<sha>`` files -> **hardlinked** (zero-copy, shared inode).
    - ``snapshots/<commit>/...`` symlinks -> **preserved** with original relative target.
    - ``refs/<branch>`` files -> **byte-copied** (HF lib overwrites on etag check),
      including nested refs such as ``refs/pr/1``.
    - Directories -> ``mkdir`` so the runtime user owns them.

    Falls back to an absolute ``symlink`` when ``os.link()`` raises EXDEV
    (cross-device). Any other ``OSError`` from ``os.link()`` propagates.

    The operation is idempotent: a destination entry that already exists —
    including a symlink whose target is currently missing — is left untouched.

    Args:
        src_root: Root of the preloaded cache (the directory containing ``hub/``).
        dst_root: Root of the writable cache to populate.

    Returns:
        None. Does nothing when ``src_root/hub`` does not exist.
    """
    import errno
    import shutil

    src_root = Path(src_root)
    dst_root = Path(dst_root)

    src_hub = src_root / "hub"
    if not src_hub.exists():
        return  # nothing preloaded -- no-op

    for src_dir, _, files in os.walk(src_hub):
        rel = Path(src_dir).relative_to(src_root)
        dst_dir = dst_root / rel
        dst_dir.mkdir(parents=True, exist_ok=True)

        # ``rel`` looks like hub/<repo-dir>/<section>/...; refs live in the
        # ``refs`` section. Matching on the path component (rather than on a
        # "refs/" substring) also catches files directly under ``refs/`` and
        # avoids false hits on repo directories whose name ends in "refs".
        in_refs = len(rel.parts) >= 3 and rel.parts[2] == "refs"

        for name in files:
            src_path = Path(src_dir) / name
            dst_path = dst_dir / name

            # ``exists()`` follows symlinks and is False for a dangling link,
            # so check ``is_symlink()`` too; otherwise a re-run would try to
            # recreate the link and fail with FileExistsError.
            if dst_path.is_symlink() or dst_path.exists():
                continue

            # Refs get byte-copied so the runtime user can overwrite them.
            if in_refs:
                shutil.copy2(src_path, dst_path)
                continue

            # Symlinks (snapshot files) preserve their relative target.
            if src_path.is_symlink():
                dst_path.symlink_to(os.readlink(src_path))
                continue

            # Regular files (blobs) hardlink with EXDEV fallback.
            try:
                os.link(src_path, dst_path)
            except OSError as e:
                if e.errno != errno.EXDEV:
                    raise
                # A relative target would be resolved against ``dst_dir`` and
                # break, so always point the fallback link at an absolute path.
                dst_path.symlink_to(os.path.abspath(src_path))


def symlink_hf_cache_to_diffsynth_layout(cache_hub: Path | str, dest_root: Path | str) -> list[str]:
    """For each ``models--<org>--<repo>`` under ``cache_hub``, symlink the latest snapshot
    dir to ``dest_root/<org>/<repo>/`` — the layout DiffSynth's ModelConfig expects.

    DiffSynth's ``download()`` joins ``local_model_path`` with ``model_id`` and either
    finds matching files (skipping download) or fetches them. Putting symlinks at the
    expected location lets DiffSynth reuse our HF-cache snapshots instead of re-downloading.

    Args:
        cache_hub: The HF cache ``hub`` directory containing ``models--*`` entries.
        dest_root: Directory under which ``<org>/<repo>`` links are created.

    Returns:
        The list of dest paths created during this call (as strings), in sorted
        entry order. Empty when ``cache_hub`` is not a directory.

    Idempotent: if anything already exists at the destination — a directory, a
    file, or a symlink (even a dangling one) — it is left untouched and not
    reported as created.
    """
    cache_hub = Path(cache_hub)
    dest_root = Path(dest_root)
    if not cache_hub.is_dir():
        return []

    prefix = "models--"
    created: list[str] = []
    for entry in sorted(cache_hub.iterdir()):
        if not entry.is_dir() or not entry.name.startswith(prefix):
            continue
        # "models--Tongyi-MAI--Z-Image-Turbo" -> ("Tongyi-MAI", "Z-Image-Turbo")
        # Some repos contain "--" in their name; only split off the first segment.
        org, sep, repo = entry.name[len(prefix) :].partition("--")
        if not sep:
            continue

        snapshots = entry / "snapshots"
        if not snapshots.is_dir():
            continue
        sha_dirs = [d for d in snapshots.iterdir() if d.is_dir()]
        if not sha_dirs:
            continue
        # Newest by mtime — usually the only one for our preload + first-fetch flow.
        snap = max(sha_dirs, key=lambda d: d.stat().st_mtime)

        link = dest_root / org / repo
        if link.is_symlink() or link.exists():
            continue
        link.parent.mkdir(parents=True, exist_ok=True)
        # Use an absolute target: a relative ``snap`` (from a relative
        # ``cache_hub``) would be resolved against ``link.parent`` and dangle.
        link.symlink_to(os.path.abspath(snap))
        created.append(str(link))
    return created