Spaces:
Running on Zero
Running on Zero
feat(lora): applied_lora ctx manager — validate, apply, revert on exit
Browse files- lora.py +49 -0
- tests/test_lora.py +58 -0
lora.py
CHANGED
|
@@ -3,8 +3,10 @@ from __future__ import annotations
|
|
| 3 |
|
| 4 |
import json
|
| 5 |
import struct
|
|
|
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from pathlib import Path
|
|
|
|
| 8 |
|
| 9 |
ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.")
|
| 10 |
|
|
@@ -68,3 +70,50 @@ def sniff(path: Path | str) -> LoRAInfo:
|
|
| 68 |
target="transformer",
|
| 69 |
size_bytes=path.stat().st_size,
|
| 70 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
import json
|
| 5 |
import struct
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import Any, Iterator
|
| 10 |
|
| 11 |
ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.")
|
| 12 |
|
|
|
|
| 70 |
target="transformer",
|
| 71 |
size_bytes=path.stat().st_size,
|
| 72 |
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@contextmanager
|
| 76 |
+
def applied_lora(pipe: Any, path: Path | str | None, strength: float) -> Iterator[None]:
|
| 77 |
+
"""Apply a LoRA to the pipeline's dit for the duration of the context.
|
| 78 |
+
|
| 79 |
+
Reverts on exit (including exception path) so the cached GPU model is left clean.
|
| 80 |
+
If ``path`` is ``None``, this is a no-op.
|
| 81 |
+
|
| 82 |
+
Validates the LoRA file with :func:`sniff` before touching the pipeline so a bad
|
| 83 |
+
file is rejected before any GPU work begins.
|
| 84 |
+
"""
|
| 85 |
+
if path is None:
|
| 86 |
+
yield
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
sniff(path) # raises LoRAValidationError on bad input
|
| 90 |
+
_apply_lora_impl(pipe, path, strength)
|
| 91 |
+
try:
|
| 92 |
+
yield
|
| 93 |
+
finally:
|
| 94 |
+
_revert_lora_impl(pipe)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _apply_lora_impl(pipe: Any, path: Path | str, strength: float) -> None:
|
| 98 |
+
"""Apply a LoRA to ``pipe.dit``. Imports DiffSynth lazily for testability."""
|
| 99 |
+
from diffsynth.utils.lora import merge_lora
|
| 100 |
+
merge_lora(pipe.dit, str(path), alpha=float(strength))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _revert_lora_impl(pipe: Any) -> None:
|
| 104 |
+
"""Revert the most recent LoRA from ``pipe.dit``.
|
| 105 |
+
|
| 106 |
+
Tries DiffSynth's ``unmerge_lora`` first; falls back to re-fetching clean
|
| 107 |
+
weights from the model pool if unavailable.
|
| 108 |
+
"""
|
| 109 |
+
try:
|
| 110 |
+
from diffsynth.utils.lora import unmerge_lora
|
| 111 |
+
unmerge_lora(pipe.dit)
|
| 112 |
+
return
|
| 113 |
+
except ImportError:
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
if hasattr(pipe, "model_pool"):
|
| 117 |
+
variant = getattr(pipe.dit, "_zis_variant", None)
|
| 118 |
+
if variant:
|
| 119 |
+
pipe.dit = pipe.model_pool.fetch_model("z_image_dit", variant=variant)
|
tests/test_lora.py
CHANGED
|
@@ -43,3 +43,61 @@ def test_sniff_rejects_non_zimage_keys(tmp_path):
|
|
| 43 |
lora.sniff(p)
|
| 44 |
msg = str(exc.value).lower()
|
| 45 |
assert "down_blocks" in msg or "unexpected" in msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
lora.sniff(p)
|
| 44 |
msg = str(exc.value).lower()
|
| 45 |
assert "down_blocks" in msg or "unexpected" in msg
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class _FakePipe:
|
| 49 |
+
"""Minimal stand-in for DiffSynth's ZImagePipeline.dit hook surface."""
|
| 50 |
+
def __init__(self):
|
| 51 |
+
self.applied = [] # list of (path, strength) tuples
|
| 52 |
+
self.reverted = []
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_applied_lora_calls_apply_then_revert(tmp_path, monkeypatch):
|
| 56 |
+
p = tmp_path / "ok.safetensors"
|
| 57 |
+
_write_safetensors(p, {
|
| 58 |
+
"transformer.x.lora_A.weight": {"dtype": "BF16", "shape": [32, 3840]},
|
| 59 |
+
"transformer.x.lora_B.weight": {"dtype": "BF16", "shape": [3840, 32]},
|
| 60 |
+
})
|
| 61 |
+
pipe = _FakePipe()
|
| 62 |
+
|
| 63 |
+
def fake_apply(pipe, path, strength):
|
| 64 |
+
pipe.applied.append((str(path), strength))
|
| 65 |
+
def fake_revert(pipe):
|
| 66 |
+
pipe.reverted.append(True)
|
| 67 |
+
monkeypatch.setattr(lora, "_apply_lora_impl", fake_apply)
|
| 68 |
+
monkeypatch.setattr(lora, "_revert_lora_impl", fake_revert)
|
| 69 |
+
|
| 70 |
+
with lora.applied_lora(pipe, p, strength=0.8):
|
| 71 |
+
assert pipe.applied == [(str(p), 0.8)]
|
| 72 |
+
assert pipe.reverted == []
|
| 73 |
+
|
| 74 |
+
assert pipe.reverted == [True]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_applied_lora_with_none_is_a_noop(tmp_path, monkeypatch):
|
| 78 |
+
pipe = _FakePipe()
|
| 79 |
+
sentinel = []
|
| 80 |
+
monkeypatch.setattr(lora, "_apply_lora_impl", lambda *a, **k: sentinel.append("apply"))
|
| 81 |
+
monkeypatch.setattr(lora, "_revert_lora_impl", lambda *a, **k: sentinel.append("revert"))
|
| 82 |
+
|
| 83 |
+
with lora.applied_lora(pipe, None, strength=0.0):
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
assert sentinel == []
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def test_applied_lora_reverts_on_exception(tmp_path, monkeypatch):
|
| 90 |
+
p = tmp_path / "ok.safetensors"
|
| 91 |
+
_write_safetensors(p, {
|
| 92 |
+
"transformer.x.lora_A.weight": {"dtype": "BF16", "shape": [16, 3840]},
|
| 93 |
+
"transformer.x.lora_B.weight": {"dtype": "BF16", "shape": [3840, 16]},
|
| 94 |
+
})
|
| 95 |
+
pipe = _FakePipe()
|
| 96 |
+
monkeypatch.setattr(lora, "_apply_lora_impl", lambda pipe, p, s: pipe.applied.append((p, s)))
|
| 97 |
+
monkeypatch.setattr(lora, "_revert_lora_impl", lambda pipe: pipe.reverted.append(True))
|
| 98 |
+
|
| 99 |
+
with pytest.raises(RuntimeError):
|
| 100 |
+
with lora.applied_lora(pipe, p, strength=1.0):
|
| 101 |
+
raise RuntimeError("inference failed mid-step")
|
| 102 |
+
|
| 103 |
+
assert pipe.reverted == [True], "must still revert on exception"
|