techfreakworm committed on
Commit
2e18e13
·
unverified ·
1 Parent(s): 9514256

fix(lora): accept diffusion_model.* prefix and use pipe.load_lora hotload

Browse files

The sniff allowlist only contained the prefixes `transformer.`, `dit.` and
`model.transformer.`, so any CivitAI / Kohya LoRA whose keys started with
`diffusion_model.*` (Toontastic, DarkGhibly, etc.) was rejected before
reaching the pipeline.

The bigger issue: _apply_lora_impl was calling diffsynth.utils.lora.merge_lora
(loras: List[Dict], alpha) as if it were a state_dict-to-module fuser. It isn't
— it's an image2lora list-merger — so even the LoRAs we did accept never
actually patched the DiT. Swapped to pipe.load_lora(module=pipe.dit, ...)
which routes through GeneralLoRALoader.convert_state_dict (strips the
diffusion_model. prefix natively) and hotloads into AutoWrappedLinear,
reverted at exit via pipe.clear_lora().

Verified live with Toon5_E10 (Z-Image-Turbo) and DarkGhiblyZBase (Z-Image
Base) on local MPS.

Files changed (2) hide show
  1. lora.py +9 -21
  2. tests/test_lora.py +15 -0
lora.py CHANGED
@@ -10,7 +10,7 @@ from dataclasses import dataclass
10
  from pathlib import Path
11
  from typing import Any
12
 
13
- ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.")
14
 
15
 
16
  class LoRAValidationError(ValueError):
@@ -97,27 +97,15 @@ def applied_lora(pipe: Any, path: Path | str | None, strength: float) -> Iterato
97
 
98
 
99
  def _apply_lora_impl(pipe: Any, path: Path | str, strength: float) -> None:
100
- """Apply a LoRA to ``pipe.dit``. Imports DiffSynth lazily for testability."""
101
- from diffsynth.utils.lora import merge_lora
102
 
103
- merge_lora(pipe.dit, str(path), alpha=float(strength))
104
-
105
-
106
- def _revert_lora_impl(pipe: Any) -> None:
107
- """Revert the most recent LoRA from ``pipe.dit``.
108
-
109
- Tries DiffSynth's ``unmerge_lora`` first; falls back to re-fetching clean
110
- weights from the model pool if unavailable.
111
  """
112
- try:
113
- from diffsynth.utils.lora import unmerge_lora
114
 
115
- unmerge_lora(pipe.dit)
116
- return
117
- except ImportError:
118
- pass
119
 
120
- if hasattr(pipe, "model_pool"):
121
- variant = getattr(pipe.dit, "_zis_variant", None)
122
- if variant:
123
- pipe.dit = pipe.model_pool.fetch_model("z_image_dit", variant=variant)
 
10
  from pathlib import Path
11
  from typing import Any
12
 
13
+ ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.", "diffusion_model.")
14
 
15
 
16
  class LoRAValidationError(ValueError):
 
97
 
98
 
99
  def _apply_lora_impl(pipe: Any, path: Path | str, strength: float) -> None:
100
+ """Apply a LoRA to ``pipe.dit`` using DiffSynth's ``load_lora`` (hotload mode).
 
101
 
102
+ ``GeneralLoRALoader.convert_state_dict`` normalises CivitAI-style
103
+ ``diffusion_model.*`` keys into the bare module-path keys DiffSynth's
104
+ AutoWrappedLinear modules consume, so we don't need to remap ourselves.
 
 
 
 
 
105
  """
106
+ pipe.load_lora(module=pipe.dit, lora_config=str(path), alpha=float(strength), verbose=0)
 
107
 
 
 
 
 
108
 
109
+ def _revert_lora_impl(pipe: Any) -> None:
110
+ """Clear the hotloaded LoRA so the cached transformer is left clean."""
111
+ pipe.clear_lora(verbose=0)
 
tests/test_lora.py CHANGED
@@ -37,6 +37,21 @@ def test_sniff_rejects_non_safetensors(tmp_path):
37
  assert "safetensors" in str(exc.value).lower()
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def test_sniff_rejects_non_zimage_keys(tmp_path):
41
  p = tmp_path / "wrong.safetensors"
42
  _write_safetensors(
 
37
  assert "safetensors" in str(exc.value).lower()
38
 
39
 
40
+ def test_sniff_accepts_diffusion_model_prefix(tmp_path):
41
+ """CivitAI / Kohya LoRAs prefix keys with ``diffusion_model.`` — must be accepted."""
42
+ p = tmp_path / "civitai.safetensors"
43
+ _write_safetensors(
44
+ p,
45
+ {
46
+ "diffusion_model.layers.0.adaLN_modulation.0.lora_A.weight": {"dtype": "BF16", "shape": [16, 3840]},
47
+ "diffusion_model.layers.0.adaLN_modulation.0.lora_B.weight": {"dtype": "BF16", "shape": [3840, 16]},
48
+ },
49
+ )
50
+ info = lora.sniff(p)
51
+ assert info.rank == 16
52
+ assert info.target == "transformer"
53
+
54
+
55
  def test_sniff_rejects_non_zimage_keys(tmp_path):
56
  p = tmp_path / "wrong.safetensors"
57
  _write_safetensors(