Spaces:
Running on Zero
fix(lora): accept diffusion_model.* prefix and use pipe.load_lora hotload
Browse filesThe sniff allowlist only had transformer./dit./model.transformer., so any
CivitAI / Kohya LoRA whose keys started with diffusion_model.* (Toontastic,
DarkGhibly, etc.) was rejected before reaching the pipeline.
The bigger issue: _apply_lora_impl was calling diffsynth.utils.lora.merge_lora
(loras: List[Dict], alpha) as if it were a state_dict-to-module fuser. It isn't
— it's an image2lora list-merger — so even the LoRAs we did accept never
actually patched the DiT. Swapped to pipe.load_lora(module=pipe.dit, ...)
which routes through GeneralLoRALoader.convert_state_dict (strips the
diffusion_model. prefix natively) and hotloads into AutoWrappedLinear,
reverted at exit via pipe.clear_lora().
Verified live with Toon5_E10 (Z-Image-Turbo) and DarkGhiblyZBase (Z-Image
Base) on local MPS.
- lora.py +9 -21
- tests/test_lora.py +15 -0
|
@@ -10,7 +10,7 @@ from dataclasses import dataclass
|
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Any
|
| 12 |
|
| 13 |
-
ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.")
|
| 14 |
|
| 15 |
|
| 16 |
class LoRAValidationError(ValueError):
|
|
@@ -97,27 +97,15 @@ def applied_lora(pipe: Any, path: Path | str | None, strength: float) -> Iterato
|
|
| 97 |
|
| 98 |
|
| 99 |
def _apply_lora_impl(pipe: Any, path: Path | str, strength: float) -> None:
|
| 100 |
-
"""Apply a LoRA to ``pipe.dit``
|
| 101 |
-
from diffsynth.utils.lora import merge_lora
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
def _revert_lora_impl(pipe: Any) -> None:
|
| 107 |
-
"""Revert the most recent LoRA from ``pipe.dit``.
|
| 108 |
-
|
| 109 |
-
Tries DiffSynth's ``unmerge_lora`` first; falls back to re-fetching clean
|
| 110 |
-
weights from the model pool if unavailable.
|
| 111 |
"""
|
| 112 |
-
|
| 113 |
-
from diffsynth.utils.lora import unmerge_lora
|
| 114 |
|
| 115 |
-
unmerge_lora(pipe.dit)
|
| 116 |
-
return
|
| 117 |
-
except ImportError:
|
| 118 |
-
pass
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
pipe.dit = pipe.model_pool.fetch_model("z_image_dit", variant=variant)
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Any
|
| 12 |
|
| 13 |
+
ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.", "diffusion_model.")
|
| 14 |
|
| 15 |
|
| 16 |
class LoRAValidationError(ValueError):
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
def _apply_lora_impl(pipe: Any, path: Path | str, strength: float) -> None:
|
| 100 |
+
"""Apply a LoRA to ``pipe.dit`` using DiffSynth's ``load_lora`` (hotload mode).
|
|
|
|
| 101 |
|
| 102 |
+
``GeneralLoRALoader.convert_state_dict`` normalises CivitAI-style
|
| 103 |
+
``diffusion_model.*`` keys into the bare module-path keys DiffSynth's
|
| 104 |
+
AutoWrappedLinear modules consume, so we don't need to remap ourselves.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
"""
|
| 106 |
+
pipe.load_lora(module=pipe.dit, lora_config=str(path), alpha=float(strength), verbose=0)
|
|
|
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
+
def _revert_lora_impl(pipe: Any) -> None:
|
| 110 |
+
"""Clear the hotloaded LoRA so the cached transformer is left clean."""
|
| 111 |
+
pipe.clear_lora(verbose=0)
|
|
|
|
@@ -37,6 +37,21 @@ def test_sniff_rejects_non_safetensors(tmp_path):
|
|
| 37 |
assert "safetensors" in str(exc.value).lower()
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def test_sniff_rejects_non_zimage_keys(tmp_path):
|
| 41 |
p = tmp_path / "wrong.safetensors"
|
| 42 |
_write_safetensors(
|
|
|
|
| 37 |
assert "safetensors" in str(exc.value).lower()
|
| 38 |
|
| 39 |
|
| 40 |
+
def test_sniff_accepts_diffusion_model_prefix(tmp_path):
|
| 41 |
+
"""CivitAI / Kohya LoRAs prefix keys with ``diffusion_model.`` — must be accepted."""
|
| 42 |
+
p = tmp_path / "civitai.safetensors"
|
| 43 |
+
_write_safetensors(
|
| 44 |
+
p,
|
| 45 |
+
{
|
| 46 |
+
"diffusion_model.layers.0.adaLN_modulation.0.lora_A.weight": {"dtype": "BF16", "shape": [16, 3840]},
|
| 47 |
+
"diffusion_model.layers.0.adaLN_modulation.0.lora_B.weight": {"dtype": "BF16", "shape": [3840, 16]},
|
| 48 |
+
},
|
| 49 |
+
)
|
| 50 |
+
info = lora.sniff(p)
|
| 51 |
+
assert info.rank == 16
|
| 52 |
+
assert info.target == "transformer"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
def test_sniff_rejects_non_zimage_keys(tmp_path):
|
| 56 |
p = tmp_path / "wrong.safetensors"
|
| 57 |
_write_safetensors(
|