techfreakworm commited on
Commit
8759971
·
unverified ·
1 Parent(s): bc2513c

feat(lora): applied_lora ctx manager — validate, apply, revert on exit

Browse files
Files changed (2) hide show
  1. lora.py +49 -0
  2. tests/test_lora.py +58 -0
lora.py CHANGED
@@ -3,8 +3,10 @@ from __future__ import annotations
3
 
4
  import json
5
  import struct
 
6
  from dataclasses import dataclass
7
  from pathlib import Path
 
8
 
9
  ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.")
10
 
@@ -68,3 +70,50 @@ def sniff(path: Path | str) -> LoRAInfo:
68
  target="transformer",
69
  size_bytes=path.stat().st_size,
70
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  import json
5
  import struct
6
+ from contextlib import contextmanager
7
  from dataclasses import dataclass
8
  from pathlib import Path
9
+ from typing import Any, Iterator
10
 
11
  ZIMAGE_LORA_PREFIXES = ("transformer.", "dit.", "model.transformer.")
12
 
 
70
  target="transformer",
71
  size_bytes=path.stat().st_size,
72
  )
73
+
74
+
75
+ @contextmanager
76
+ def applied_lora(pipe: Any, path: Path | str | None, strength: float) -> Iterator[None]:
77
+ """Apply a LoRA to the pipeline's dit for the duration of the context.
78
+
79
+ Reverts on exit (including exception path) so the cached GPU model is left clean.
80
+ If ``path`` is ``None``, this is a no-op.
81
+
82
+ Validates the LoRA file with :func:`sniff` before touching the pipeline so a bad
83
+ file is rejected before any GPU work begins.
84
+ """
85
+ if path is None:
86
+ yield
87
+ return
88
+
89
+ sniff(path) # raises LoRAValidationError on bad input
90
+ _apply_lora_impl(pipe, path, strength)
91
+ try:
92
+ yield
93
+ finally:
94
+ _revert_lora_impl(pipe)
95
+
96
+
97
+ def _apply_lora_impl(pipe: Any, path: Path | str, strength: float) -> None:
98
+ """Apply a LoRA to ``pipe.dit``. Imports DiffSynth lazily for testability."""
99
+ from diffsynth.utils.lora import merge_lora
100
+ merge_lora(pipe.dit, str(path), alpha=float(strength))
101
+
102
+
103
+ def _revert_lora_impl(pipe: Any) -> None:
104
+ """Revert the most recent LoRA from ``pipe.dit``.
105
+
106
+ Tries DiffSynth's ``unmerge_lora`` first; falls back to re-fetching clean
107
+ weights from the model pool if unavailable.
108
+ """
109
+ try:
110
+ from diffsynth.utils.lora import unmerge_lora
111
+ unmerge_lora(pipe.dit)
112
+ return
113
+ except ImportError:
114
+ pass
115
+
116
+ if hasattr(pipe, "model_pool"):
117
+ variant = getattr(pipe.dit, "_zis_variant", None)
118
+ if variant:
119
+ pipe.dit = pipe.model_pool.fetch_model("z_image_dit", variant=variant)
tests/test_lora.py CHANGED
@@ -43,3 +43,61 @@ def test_sniff_rejects_non_zimage_keys(tmp_path):
43
  lora.sniff(p)
44
  msg = str(exc.value).lower()
45
  assert "down_blocks" in msg or "unexpected" in msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  lora.sniff(p)
44
  msg = str(exc.value).lower()
45
  assert "down_blocks" in msg or "unexpected" in msg
46
+
47
+
48
+ class _FakePipe:
49
+ """Minimal stand-in for DiffSynth's ZImagePipeline.dit hook surface."""
50
+ def __init__(self):
51
+ self.applied = [] # list of (path, strength) tuples
52
+ self.reverted = []
53
+
54
+
55
+ def test_applied_lora_calls_apply_then_revert(tmp_path, monkeypatch):
56
+ p = tmp_path / "ok.safetensors"
57
+ _write_safetensors(p, {
58
+ "transformer.x.lora_A.weight": {"dtype": "BF16", "shape": [32, 3840]},
59
+ "transformer.x.lora_B.weight": {"dtype": "BF16", "shape": [3840, 32]},
60
+ })
61
+ pipe = _FakePipe()
62
+
63
+ def fake_apply(pipe, path, strength):
64
+ pipe.applied.append((str(path), strength))
65
+ def fake_revert(pipe):
66
+ pipe.reverted.append(True)
67
+ monkeypatch.setattr(lora, "_apply_lora_impl", fake_apply)
68
+ monkeypatch.setattr(lora, "_revert_lora_impl", fake_revert)
69
+
70
+ with lora.applied_lora(pipe, p, strength=0.8):
71
+ assert pipe.applied == [(str(p), 0.8)]
72
+ assert pipe.reverted == []
73
+
74
+ assert pipe.reverted == [True]
75
+
76
+
77
+ def test_applied_lora_with_none_is_a_noop(tmp_path, monkeypatch):
78
+ pipe = _FakePipe()
79
+ sentinel = []
80
+ monkeypatch.setattr(lora, "_apply_lora_impl", lambda *a, **k: sentinel.append("apply"))
81
+ monkeypatch.setattr(lora, "_revert_lora_impl", lambda *a, **k: sentinel.append("revert"))
82
+
83
+ with lora.applied_lora(pipe, None, strength=0.0):
84
+ pass
85
+
86
+ assert sentinel == []
87
+
88
+
89
+ def test_applied_lora_reverts_on_exception(tmp_path, monkeypatch):
90
+ p = tmp_path / "ok.safetensors"
91
+ _write_safetensors(p, {
92
+ "transformer.x.lora_A.weight": {"dtype": "BF16", "shape": [16, 3840]},
93
+ "transformer.x.lora_B.weight": {"dtype": "BF16", "shape": [3840, 16]},
94
+ })
95
+ pipe = _FakePipe()
96
+ monkeypatch.setattr(lora, "_apply_lora_impl", lambda pipe, p, s: pipe.applied.append((p, s)))
97
+ monkeypatch.setattr(lora, "_revert_lora_impl", lambda pipe: pipe.reverted.append(True))
98
+
99
+ with pytest.raises(RuntimeError):
100
+ with lora.applied_lora(pipe, p, strength=1.0):
101
+ raise RuntimeError("inference failed mid-step")
102
+
103
+ assert pipe.reverted == [True], "must still revert on exception"