# z-image-studio / tests/test_lora.py
# techfreakworm's picture
# fix(lora): accept diffusion_model.* prefix and use pipe.load_lora hotload
# 2e18e13 unverified
import json
import struct
from pathlib import Path
import pytest
import lora
def _write_safetensors(path: Path, header: dict) -> None:
    """Write a tensor-less safetensors stub to *path*.

    The file is the JSON-serialised *header* (UTF-8), preceded by its byte
    length packed as an unsigned 64-bit little-endian integer. Nothing
    follows the header, i.e. there is no tensor payload.
    """
    encoded_header = json.dumps(header).encode("utf-8")
    length_prefix = struct.pack("<Q", len(encoded_header))
    path.write_bytes(b"".join((length_prefix, encoded_header)))
def test_sniff_valid_zimage_lora_returns_metadata(tmp_path):
    """A header with ``transformer.*`` LoRA keys yields rank, target and size."""
    header = {
        "transformer.layer1.lora_A.weight": {"dtype": "BF16", "shape": [64, 3840]},
        "transformer.layer1.lora_B.weight": {"dtype": "BF16", "shape": [3840, 64]},
        "__metadata__": {"rank": "64"},
    }
    lora_file = tmp_path / "ok.safetensors"
    _write_safetensors(lora_file, header)

    sniffed = lora.sniff(lora_file)

    assert sniffed.rank == 64
    assert sniffed.target == "transformer"
    # The reported size has to agree with what is actually on disk.
    assert sniffed.size_bytes == lora_file.stat().st_size
def test_sniff_rejects_non_safetensors(tmp_path):
    """Arbitrary bytes are rejected with an error that mentions safetensors."""
    bogus_file = tmp_path / "bad.bin"
    bogus_file.write_bytes(b"this is not a safetensors file at all")

    with pytest.raises(lora.LoRAValidationError) as caught:
        lora.sniff(bogus_file)

    # Checked outside the ``raises`` block: code after the raising call
    # inside it would never execute.
    message = str(caught.value)
    assert "safetensors" in message.lower()
def test_sniff_accepts_diffusion_model_prefix(tmp_path):
    """Keys prefixed with ``diffusion_model.`` (CivitAI / Kohya LoRAs) must be accepted."""
    header = {
        "diffusion_model.layers.0.adaLN_modulation.0.lora_A.weight": {"dtype": "BF16", "shape": [16, 3840]},
        "diffusion_model.layers.0.adaLN_modulation.0.lora_B.weight": {"dtype": "BF16", "shape": [3840, 16]},
    }
    lora_file = tmp_path / "civitai.safetensors"
    _write_safetensors(lora_file, header)

    sniffed = lora.sniff(lora_file)

    # No ``__metadata__`` entry is written here, yet a rank of 16 is expected.
    assert sniffed.rank == 16
    assert sniffed.target == "transformer"
def test_sniff_rejects_non_zimage_keys(tmp_path):
    """A header whose only key is ``down_blocks.*`` is rejected with a telling message."""
    foreign_file = tmp_path / "wrong.safetensors"
    foreign_header = {
        "down_blocks.0.weight": {"dtype": "F32", "shape": [320, 320]},
    }
    _write_safetensors(foreign_file, foreign_header)

    with pytest.raises(lora.LoRAValidationError) as caught:
        lora.sniff(foreign_file)

    # Either the offending key or the word "unexpected" must appear.
    lowered = str(caught.value).lower()
    assert any(token in lowered for token in ("down_blocks", "unexpected"))
class _FakePipe:
    """Bare-bones double for the hook surface of DiffSynth's ZImagePipeline.dit.

    It carries no behaviour of its own; it only exposes two lists that the
    stubbed apply/revert hooks in these tests append to.
    """

    def __init__(self) -> None:
        # ``applied`` collects (path, strength) tuples; ``reverted`` collects
        # whatever the revert stub records. Both start empty per instance.
        self.applied, self.reverted = [], []
def test_applied_lora_calls_apply_then_revert(tmp_path, monkeypatch):
    """``applied_lora`` applies on entry and reverts only once the block exits."""
    header = {
        "transformer.x.lora_A.weight": {"dtype": "BF16", "shape": [32, 3840]},
        "transformer.x.lora_B.weight": {"dtype": "BF16", "shape": [3840, 32]},
    }
    lora_file = tmp_path / "ok.safetensors"
    _write_safetensors(lora_file, header)

    def record_apply(pipe, path, strength):
        # Store the path as ``str`` so the comparison below does not depend
        # on the type handed to the hook.
        pipe.applied.append((str(path), strength))

    def record_revert(pipe):
        pipe.reverted.append(True)

    monkeypatch.setattr(lora, "_apply_lora_impl", record_apply)
    monkeypatch.setattr(lora, "_revert_lora_impl", record_revert)

    fake = _FakePipe()
    with lora.applied_lora(fake, lora_file, strength=0.8):
        # Inside the block: applied exactly once, not yet reverted.
        assert fake.applied == [(str(lora_file), 0.8)]
        assert fake.reverted == []

    # After the block: reverted exactly once.
    assert fake.reverted == [True]
def test_applied_lora_with_none_is_a_noop(tmp_path, monkeypatch):
    """Passing ``None`` as the LoRA path must call neither apply nor revert."""
    calls = []

    def note_apply(*args, **kwargs):
        calls.append("apply")

    def note_revert(*args, **kwargs):
        calls.append("revert")

    monkeypatch.setattr(lora, "_apply_lora_impl", note_apply)
    monkeypatch.setattr(lora, "_revert_lora_impl", note_revert)

    fake = _FakePipe()
    with lora.applied_lora(fake, None, strength=0.0):
        pass

    # Any hook invocation would have left a marker behind.
    assert calls == []
def test_applied_lora_reverts_on_exception(tmp_path, monkeypatch):
    """The LoRA must be reverted even when the ``with`` body raises.

    The apply/revert hooks on ``lora`` are replaced with recording stubs; the
    body of ``applied_lora`` then raises, and the test checks that the
    original error propagates and that revert ran exactly once.
    """
    p = tmp_path / "ok.safetensors"
    _write_safetensors(
        p,
        {
            "transformer.x.lora_A.weight": {"dtype": "BF16", "shape": [16, 3840]},
            "transformer.x.lora_B.weight": {"dtype": "BF16", "shape": [3840, 16]},
        },
    )
    pipe = _FakePipe()

    # Named parameters (pipe, path, strength) match the other apply stub in
    # this file, so the stub keeps working if the hook is called by keyword,
    # and the outer ``p`` is no longer shadowed.
    def fake_apply(pipe, path, strength):
        pipe.applied.append((str(path), strength))

    def fake_revert(pipe):
        pipe.reverted.append(True)

    monkeypatch.setattr(lora, "_apply_lora_impl", fake_apply)
    monkeypatch.setattr(lora, "_revert_lora_impl", fake_revert)

    # ``match`` pins the propagated error to the one raised in the body, so
    # an unrelated RuntimeError from ``lora`` cannot satisfy the test.
    with pytest.raises(RuntimeError, match="inference failed mid-step"):
        with lora.applied_lora(pipe, p, strength=1.0):
            # Precondition: the LoRA really was applied before the failure,
            # otherwise the revert check below proves little.
            assert pipe.applied == [(str(p), 1.0)]
            raise RuntimeError("inference failed mid-step")

    assert pipe.reverted == [True], "must still revert on exception"