File size: 3,751 Bytes
bc2513c
9a5065c
bc2513c
 
 
 
9a5065c
8759971
bc2513c
 
9a5065c
bc2513c
2e18e13
bc2513c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8759971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e18e13
9a5065c
2e18e13
 
 
8759971
2e18e13
9a5065c
8759971
2e18e13
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""LoRA file validation and apply/revert context manager."""

from __future__ import annotations

import json
import struct
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any

# Tensor-key prefixes a LoRA file may use to target the Z-Image transformer;
# any tensor key outside these is rejected by ``sniff``.
ZIMAGE_LORA_PREFIXES = (
    "transformer.",
    "dit.",
    "model.transformer.",
    "diffusion_model.",
)


class LoRAValidationError(ValueError):
    """A LoRA safetensors file is malformed or does not fit Z-Image's key layout.

    Derives from :class:`ValueError`, so handlers for generic bad-input errors
    also catch it.
    """


@dataclass(frozen=True)
class LoRAInfo:
    """Immutable summary of a LoRA file, produced by ``sniff`` from its header."""

    path: Path  # location of the .safetensors file
    rank: int  # LoRA rank; 0 when it could not be determined from the header
    target: str  # which submodule it applies to ("transformer" for Z-Image)
    size_bytes: int  # on-disk size of the file


def sniff(path: Path | str) -> LoRAInfo:
    """Validate a LoRA safetensors file from its header and infer rank + target.

    Only the 8-byte length prefix and the JSON header are read from disk; tensor
    data is never loaded and no GPU memory is allocated, so this is cheap enough
    to call before ``@spaces.GPU`` fires.

    Args:
        path: Path to the ``.safetensors`` LoRA file.

    Returns:
        A :class:`LoRAInfo` whose ``rank`` comes from ``__metadata__["rank"]`` when
        that parses as a positive integer, otherwise from the shape of the first
        ``lora_A``/``lora_down`` tensor, otherwise 0; ``target`` is always
        ``"transformer"`` and ``size_bytes`` is the on-disk file size.

    Raises:
        LoRAValidationError: if the header is truncated or not a JSON object, the
            file lists no tensors, or any tensor key falls outside
            ``ZIMAGE_LORA_PREFIXES``.
        OSError: if the file cannot be stat'ed or opened.
    """
    path = Path(path)
    size_bytes = path.stat().st_size
    header = _read_safetensors_header(path, size_bytes)

    # Keys starting with "__" (e.g. "__metadata__") describe the file, not tensors.
    tensor_keys = [k for k in header if not k.startswith("__")]
    if not tensor_keys:
        raise LoRAValidationError(f"{path.name}: no tensors in file")

    bad = [k for k in tensor_keys if not k.startswith(ZIMAGE_LORA_PREFIXES)]
    if bad:
        sample = bad[0]
        raise LoRAValidationError(
            f"{path.name}: unexpected key '{sample}' — Z-Image LoRAs must target "
            f"{ZIMAGE_LORA_PREFIXES} (got {len(bad)}/{len(tensor_keys)} mismatched keys)"
        )

    meta = header.get("__metadata__")
    if not isinstance(meta, dict):
        meta = {}
    try:
        rank = int(meta.get("rank", 0))
    except (TypeError, ValueError):
        # safetensors metadata values are strings; treat a non-numeric "rank"
        # (e.g. "auto") as missing instead of failing validation.
        rank = 0
    if rank <= 0:
        rank = _infer_rank(header)

    return LoRAInfo(
        path=path,
        rank=rank,
        target="transformer",
        size_bytes=size_bytes,
    )


def _read_safetensors_header(path: Path, size_bytes: int) -> dict[str, Any]:
    """Read and parse only the length-prefixed JSON header of a safetensors file.

    Raises:
        LoRAValidationError: if the prefix/header is truncated, the declared length
            exceeds the file, or the header is not a UTF-8 JSON object.
    """
    with path.open("rb") as fh:
        prefix = fh.read(8)
        if len(prefix) < 8:
            raise LoRAValidationError(f"{path.name}: file too short to be safetensors")
        (header_len,) = struct.unpack("<Q", prefix)
        # Bound the length against the real file size *before* reading, so a
        # corrupt prefix cannot trigger an oversized read.
        if header_len <= 0 or header_len + 8 > size_bytes:
            raise LoRAValidationError(f"{path.name}: not a valid safetensors header")
        header_bytes = fh.read(header_len)
    if len(header_bytes) != header_len:
        # File shrank between stat() and read().
        raise LoRAValidationError(f"{path.name}: not a valid safetensors header")

    try:
        header = json.loads(header_bytes)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # json.loads on bytes raises UnicodeDecodeError (not JSONDecodeError) for
        # undecodable input; both mean the header is unusable.
        raise LoRAValidationError(f"{path.name}: safetensors header is not JSON ({e})") from e
    if not isinstance(header, dict):
        raise LoRAValidationError(f"{path.name}: safetensors header is not a JSON object")
    return header


def _infer_rank(header: dict[str, Any]) -> int:
    """Return ``min(shape)`` of the first ``lora_A``/``lora_down`` entry, or 0 if none.

    Heuristic: for a linear down-projection the smaller dimension is presumably
    the rank (assumes rank < in_features).
    """
    for key, entry in header.items():
        if "lora_A" not in key and "lora_down" not in key:
            continue
        shape = entry.get("shape") if isinstance(entry, dict) else None
        if shape:
            return int(min(shape))
    return 0


@contextmanager
def applied_lora(pipe: Any, path: Path | str | None, strength: float) -> Iterator[None]:
    """Apply a LoRA to the pipeline's dit for the duration of the context.

    Reverts on exit (including the exception path) so the cached GPU model is left
    clean. If ``path`` is ``None``, this is a no-op.

    Validates the LoRA file with :func:`sniff` before touching the pipeline so a bad
    file is rejected before any GPU work begins.

    Args:
        pipe: Pipeline object exposing ``dit``, ``load_lora`` and ``clear_lora``.
        path: LoRA ``.safetensors`` file, or ``None`` to skip.
        strength: Scale passed to ``load_lora`` as ``alpha``.

    Raises:
        LoRAValidationError: if ``sniff`` rejects the file (pipeline untouched).
    """
    if path is None:
        yield
        return

    sniff(path)  # raises LoRAValidationError on bad input, before touching pipe
    try:
        # Apply inside the try: if load_lora fails part-way, some modules may
        # already carry LoRA weights, and the finally still clears them.
        _apply_lora_impl(pipe, path, strength)
        yield
    finally:
        _revert_lora_impl(pipe)


def _apply_lora_impl(pipe: Any, path: Path | str, strength: float) -> None:
    """Apply a LoRA to ``pipe.dit`` using DiffSynth's ``load_lora`` (hotload mode).

    ``GeneralLoRALoader.convert_state_dict`` normalises CivitAI-style
    ``diffusion_model.*`` keys into the bare module-path keys DiffSynth's
    AutoWrappedLinear modules consume, so we don't need to remap ourselves.
    """
    pipe.load_lora(module=pipe.dit, lora_config=str(path), alpha=float(strength), verbose=0)


def _revert_lora_impl(pipe: Any) -> None:
    """Clear the hotloaded LoRA so the cached transformer is left clean."""
    pipe.clear_lora(verbose=0)