File size: 6,726 Bytes
c9f8dd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e701df3
c9f8dd1
 
 
 
e701df3
 
c9f8dd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321117e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e701df3
 
 
 
 
 
 
 
901e296
 
 
 
e701df3
 
901e296
 
 
 
 
 
 
 
 
e701df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""LoRA stack: sniff/validate user-uploaded .safetensors files and
manage which one is active on the ACE-Step DiT handler.

Single-LoRA semantics
---------------------
The Apple-Silicon ACE-Step fork's AceStepHandler exposes a one-LoRA-
at-a-time API (load_lora / unload_lora / set_use_lora / set_lora_scale),
not the multi-adapter PEFT pattern the plan's Task D3 originally
described. ``apply_stack(pipe, stack)`` therefore supports:

- empty stack -> ``unload_lora`` + ``set_use_lora(False)``
- single-entry stack -> ``load_lora(path)`` + ``set_lora_scale(scale)``
  + ``set_use_lora(True)``
- multi-entry stack -> use only the first, log a warning

If the upstream pipeline ever exposes multi-adapter support, this
function can be extended without changing the wrapper's call sites.
"""

from __future__ import annotations

import json
import logging
import struct
from dataclasses import dataclass
from pathlib import Path

# Logger shared by this module (used by apply_stack for the multi-LoRA warning).
_log = logging.getLogger("ams.lora")

# Expected DiT module suffixes for ACE-Step 1.5 XL SFT.
# sniff() matches these against tensor names of the form
# `*.<suffix>.lora_A.weight` / `*.<suffix>.lora_B.weight`.
_EXPECTED_MODULES = {"to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"}
# Largest file sniff() will open. Note this is 500 MiB (524,288,000 bytes);
# the error message in sniff() divides by 1e6 and therefore prints "524 MB".
_MAX_FILE_BYTES = 500 * 1024 * 1024  # 500 MiB cap
# Largest LoRA rank that sniff() still reports as compatible.
_MAX_RANK = 256


class LoRAValidationError(ValueError):
    """Signal that a LoRA file or preset request was rejected.

    Derives from ``ValueError`` so callers that already handle
    ``ValueError`` keep working, while callers that care can catch this
    type specifically.
    """


@dataclass
class LoRAInfo:
    """Summary of a LoRA file as produced by ``sniff`` from its header."""

    # The inspected file, as a ``Path``.
    path: Path
    # True when sniff() saw an ACE-Step key prefix, at least one expected
    # module suffix, and a rank in the range 1.._MAX_RANK.
    compatible: bool
    # Largest first dimension among matched ``lora_A.weight`` tensors;
    # 0 when none were found.
    rank: int
    # ``lora_alpha`` from the header's ``__metadata__`` when present and
    # convertible to int, otherwise None.
    alpha: int | None
    # Subset of _EXPECTED_MODULES whose lora_A / lora_B weights were found.
    target_modules: set[str]
    # "OK" when compatible, otherwise a human-readable explanation.
    diagnostic: str
    # Size of the file on disk in bytes (``st_size``).
    file_size: int


def sniff(path: Path | str) -> LoRAInfo:
    """Inspect a .safetensors LoRA by reading only its JSON header.

    Tensor data is never materialised: the function reads the 8-byte
    little-endian header length, then the JSON header, and classifies the
    file from tensor names and shapes alone.

    Args:
        path: Location of the candidate ``.safetensors`` file.

    Returns:
        A ``LoRAInfo``. ``compatible`` is true only when at least one tensor
        key starts with ``transformer.`` / ``transformer_blocks.``, at least
        one expected module suffix carries ``lora_A`` / ``lora_B`` weights,
        and the detected rank lies in ``1.._MAX_RANK``. Otherwise
        ``diagnostic`` explains the rejection.

    Raises:
        LoRAValidationError: If the file is missing, is not a regular file,
            exceeds ``_MAX_FILE_BYTES``, is truncated, or carries a header
            that is not a JSON object with well-formed tensor shapes.
    """
    path = Path(path)
    if not path.exists():
        raise LoRAValidationError(f"File not found: {path}")
    if not path.is_file():
        # Directories and other special files would otherwise surface as a
        # raw OSError from open() instead of a validation error.
        raise LoRAValidationError(f"Not a regular file: {path}")

    file_size = path.stat().st_size
    if file_size > _MAX_FILE_BYTES:
        raise LoRAValidationError(
            f"File too large ({file_size / 1e6:.0f} MB > {_MAX_FILE_BYTES / 1e6:.0f} MB cap)."
        )

    # Refuse absurd header lengths before allocating a buffer for them.
    max_header_bytes = 10 * 1024 * 1024

    with path.open("rb") as f:
        header_len_bytes = f.read(8)
        if len(header_len_bytes) < 8:
            raise LoRAValidationError("Not a valid .safetensors file (truncated)")
        header_len = struct.unpack("<Q", header_len_bytes)[0]
        if header_len <= 0 or header_len > max_header_bytes:
            raise LoRAValidationError(f"Unreasonable header length: {header_len}")
        header_bytes = f.read(header_len)

    if len(header_bytes) < header_len:
        # The declared header runs past end-of-file.
        raise LoRAValidationError("Not a valid .safetensors file (truncated)")

    try:
        header = json.loads(header_bytes)
    except (ValueError, RecursionError) as e:
        # ValueError covers both JSONDecodeError and the UnicodeDecodeError
        # raised for non-UTF-8 bytes; RecursionError guards against
        # pathologically nested input in an untrusted upload.
        raise LoRAValidationError(f"Invalid header JSON: {e}") from e
    if not isinstance(header, dict):
        raise LoRAValidationError(
            f"Invalid header JSON: expected an object, got {type(header).__name__}"
        )

    target_modules: set[str] = set()
    rank = 0
    alpha = None
    has_ace_prefix = False

    for key, entry in header.items():
        if key == "__metadata__":
            if isinstance(entry, dict) and "lora_alpha" in entry:
                try:
                    alpha = int(entry["lora_alpha"])
                except (TypeError, ValueError, OverflowError):
                    # alpha is advisory only; an unparsable value stays None.
                    pass
            continue
        if not isinstance(entry, dict) or "shape" not in entry:
            continue
        # ACE-Step DiT keys start with "transformer." (the diffusers DiT prefix).
        # SDXL UNet LoRAs start with "unet." — reject those even though the
        # inner attention layer names overlap (`.to_q.lora_A.weight`).
        if key.startswith(("transformer.", "transformer_blocks.")):
            has_ace_prefix = True
        # Extract module suffix from things like "transformer.blocks.0.attn.to_q.lora_A.weight"
        for suffix in _EXPECTED_MODULES:
            if f".{suffix}.lora_A.weight" in key or f".{suffix}.lora_B.weight" in key:
                target_modules.add(suffix)
                if "lora_A.weight" in key:
                    # lora_A has shape (rank, in_features): dim 0 is the rank.
                    try:
                        rank = max(rank, int(entry["shape"][0]))
                    except (TypeError, ValueError, IndexError, KeyError, OverflowError) as e:
                        raise LoRAValidationError(
                            f"Malformed tensor shape for {key!r}: {entry['shape']!r}"
                        ) from e
                break

    compatible = has_ace_prefix and bool(target_modules) and (rank > 0) and (rank <= _MAX_RANK)
    if compatible:
        diagnostic = "OK"
    elif has_ace_prefix and target_modules and rank > _MAX_RANK:
        # Modules matched; the only problem is the rank, so say so.
        diagnostic = f"LoRA rank {rank} exceeds the supported maximum of {_MAX_RANK}."
    else:
        diagnostic = (
            f"Expected ACE-Step DiT modules ({sorted(_EXPECTED_MODULES)}), got modules in: "
            f"{sorted(set(header.keys()) - {'__metadata__'})[:3]}…"
        )

    return LoRAInfo(
        path=path,
        compatible=compatible,
        rank=rank,
        alpha=alpha,
        target_modules=target_modules,
        diagnostic=diagnostic,
        file_size=file_size,
    )


# Bundled preset manifest: ``presets/manifest.json`` in this module's directory.
_PRESETS_PATH = Path(__file__).resolve().parent / "presets" / "manifest.json"


def load_presets() -> list[dict]:
    """Load the bundled LoRA preset manifest.

    Returns:
        The parsed JSON content of ``_PRESETS_PATH``. ``download_preset``
        reads the ``name``, ``hf_id`` and ``filename`` keys of each entry.

    Raises:
        OSError: If the manifest file cannot be read.
        json.JSONDecodeError: If the manifest is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    return json.loads(_PRESETS_PATH.read_text(encoding="utf-8"))


def download_preset(name: str) -> Path:
    """Download a preset LoRA from HF if not already cached.

    Args:
        name: The ``name`` field of an entry in the bundled preset manifest.

    Returns:
        The local path reported by ``hf_hub_download``.

    Raises:
        LoRAValidationError: If the preset name is unknown, or the download
            fails with an HTTP error or any other ``OSError`` (for example a
            connection failure or a missing local cache entry while offline).
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import HfHubHTTPError

    # First manifest entry with a matching name wins.
    preset = next((p for p in load_presets() if p["name"] == name), None)
    if preset is None:
        raise LoRAValidationError(f"Unknown preset: {name}")

    try:
        local = hf_hub_download(repo_id=preset["hf_id"], filename=preset["filename"])
    except (HfHubHTTPError, OSError) as e:
        # HfHubHTTPError covers HTTP-level failures (404, auth, ...); OSError
        # additionally covers transport and cache errors that are not HTTP
        # responses, so callers only ever need to handle LoRAValidationError.
        raise LoRAValidationError(
            f"Could not download preset {name!r} from {preset['hf_id']!r}: {e}"
        ) from e
    return Path(local)


def apply_stack(pipe, stack: list[dict]) -> None:
    """Make ``stack`` the active LoRA configuration on the pipe's DiT handler.

    The handler holds a single LoRA at a time, so the stack is interpreted
    as follows:

    - Empty ``stack``: unload and disable whatever is active. When the
      pipeline has not been loaded yet there is no handler, so nothing
      happens.
    - Non-empty ``stack``: load ``stack[0]["path"]``, apply
      ``stack[0]["scale"]`` (converted to ``float``) and enable LoRA. A
      pipeline that is still cold is loaded first. Entries beyond the
      first are ignored and a warning naming the activated entry is logged.
    """
    if pipe._dit is None:
        if not stack:
            # Cold pipe and nothing requested: no handler to reset.
            return
        # A LoRA was requested, so the DiT must exist before attaching to it.
        pipe._ensure_loaded()

    handler = pipe._dit  # internal AceStepHandler reference

    if not stack:
        handler.unload_lora()
        handler.set_use_lora(False)
        return

    entry = stack[0]
    requested = len(stack)
    if requested > 1:
        _log.warning(
            "apply_stack received %d LoRAs but only one is supported by "
            "the apple-silicon ACE-Step fork; activating %r and ignoring the rest.",
            requested,
            entry["name"],
        )

    handler.load_lora(entry["path"])
    handler.set_lora_scale(float(entry["scale"]))
    handler.set_use_lora(True)