File size: 4,265 Bytes
bc2513c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5065c
 
 
 
 
 
 
 
bc2513c
 
 
 
 
 
 
 
 
 
 
 
 
 
2e18e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc2513c
 
9a5065c
 
 
 
 
 
bc2513c
 
 
 
8759971
 
 
 
9a5065c
8759971
9a5065c
8759971
 
 
 
 
9a5065c
 
 
 
 
 
 
8759971
 
 
 
9a5065c
8759971
 
9a5065c
8759971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5065c
 
 
 
 
 
 
8759971
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import json
import struct
from pathlib import Path

import pytest

import lora


def _write_safetensors(path: Path, header: dict) -> None:
    """Minimal safetensors file: 8-byte LE header length + JSON header (no tensor data)."""
    h = json.dumps(header).encode("utf-8")
    path.write_bytes(struct.pack("<Q", len(h)) + h)


def test_sniff_valid_zimage_lora_returns_metadata(tmp_path):
    """A well-formed ``transformer.*`` LoRA header yields rank, target and size."""
    header = {
        "transformer.layer1.lora_A.weight": {"dtype": "BF16", "shape": [64, 3840]},
        "transformer.layer1.lora_B.weight": {"dtype": "BF16", "shape": [3840, 64]},
        "__metadata__": {"rank": "64"},
    }
    lora_file = tmp_path / "ok.safetensors"
    _write_safetensors(lora_file, header)

    info = lora.sniff(lora_file)

    # Rank arrives as the string "64" in __metadata__ but must surface as an int.
    assert info.rank == 64
    assert info.target == "transformer"
    assert info.size_bytes == lora_file.stat().st_size


def test_sniff_rejects_non_safetensors(tmp_path):
    """Arbitrary bytes must raise ``LoRAValidationError`` mentioning safetensors."""
    garbage = tmp_path / "bad.bin"
    garbage.write_bytes(b"this is not a safetensors file at all")

    with pytest.raises(lora.LoRAValidationError) as excinfo:
        lora.sniff(garbage)

    message = str(excinfo.value).lower()
    assert "safetensors" in message


def test_sniff_accepts_diffusion_model_prefix(tmp_path):
    """Keys prefixed with ``diffusion_model.`` must be accepted.

    CivitAI / Kohya LoRA exports use this prefix. No ``__metadata__`` is
    supplied here, so the expected rank (16) is the one implied by the
    lora_A / lora_B shapes, and the target is still ``"transformer"``.
    """
    header = {
        "diffusion_model.layers.0.adaLN_modulation.0.lora_A.weight": {"dtype": "BF16", "shape": [16, 3840]},
        "diffusion_model.layers.0.adaLN_modulation.0.lora_B.weight": {"dtype": "BF16", "shape": [3840, 16]},
    }
    civitai_file = tmp_path / "civitai.safetensors"
    _write_safetensors(civitai_file, header)

    info = lora.sniff(civitai_file)

    assert info.rank == 16
    assert info.target == "transformer"


def test_sniff_rejects_non_zimage_keys(tmp_path):
    """A header whose keys are not Z-Image LoRA keys must be rejected.

    The error message should either name the offending key prefix
    (``down_blocks``) or describe it as unexpected.
    """
    foreign = tmp_path / "wrong.safetensors"
    _write_safetensors(
        foreign,
        {"down_blocks.0.weight": {"dtype": "F32", "shape": [320, 320]}},
    )

    with pytest.raises(lora.LoRAValidationError) as excinfo:
        lora.sniff(foreign)

    message = str(excinfo.value).lower()
    assert any(hint in message for hint in ("down_blocks", "unexpected"))


class _FakePipe:
    """Minimal stand-in for DiffSynth's ZImagePipeline.dit hook surface."""

    def __init__(self):
        self.applied = []  # list of (path, strength) tuples
        self.reverted = []


def test_applied_lora_calls_apply_then_revert(tmp_path, monkeypatch):
    """``applied_lora`` applies the LoRA on entry and reverts it once on exit."""
    lora_file = tmp_path / "ok.safetensors"
    _write_safetensors(
        lora_file,
        {
            "transformer.x.lora_A.weight": {"dtype": "BF16", "shape": [32, 3840]},
            "transformer.x.lora_B.weight": {"dtype": "BF16", "shape": [3840, 32]},
        },
    )
    pipe = _FakePipe()

    # Recording fakes in place of the real weight-patching hooks. Keep the
    # parameter names as-is in case ``lora`` passes them by keyword.
    def record_apply(pipe, path, strength):
        pipe.applied.append((str(path), strength))

    def record_revert(pipe):
        pipe.reverted.append(True)

    monkeypatch.setattr(lora, "_apply_lora_impl", record_apply)
    monkeypatch.setattr(lora, "_revert_lora_impl", record_revert)

    with lora.applied_lora(pipe, lora_file, strength=0.8):
        # Inside the context: applied exactly once, not reverted yet.
        assert pipe.applied == [(str(lora_file), 0.8)]
        assert pipe.reverted == []

    # Leaving the context triggers exactly one revert.
    assert pipe.reverted == [True]


def test_applied_lora_with_none_is_a_noop(monkeypatch):
    """Passing ``None`` as the LoRA path must call neither apply nor revert.

    No file is involved, so the ``tmp_path`` fixture is not requested.
    """
    pipe = _FakePipe()
    calls = []  # log of hook invocations; must stay empty
    monkeypatch.setattr(lora, "_apply_lora_impl", lambda *a, **k: calls.append("apply"))
    monkeypatch.setattr(lora, "_revert_lora_impl", lambda *a, **k: calls.append("revert"))

    with lora.applied_lora(pipe, None, strength=0.0):
        pass

    assert calls == []


def test_applied_lora_reverts_on_exception(tmp_path, monkeypatch):
    """An exception inside ``applied_lora`` must still trigger the revert hook.

    The exception itself must propagate unchanged to the caller.
    """
    p = tmp_path / "ok.safetensors"
    _write_safetensors(
        p,
        {
            "transformer.x.lora_A.weight": {"dtype": "BF16", "shape": [16, 3840]},
            "transformer.x.lora_B.weight": {"dtype": "BF16", "shape": [3840, 16]},
        },
    )
    pipe = _FakePipe()

    # Same (pipe, path, strength) parameter names as the other apply fake, so
    # the fakes work whether lora passes arguments positionally or by keyword.
    def fake_apply(pipe, path, strength):
        pipe.applied.append((str(path), strength))

    def fake_revert(pipe):
        pipe.reverted.append(True)

    monkeypatch.setattr(lora, "_apply_lora_impl", fake_apply)
    monkeypatch.setattr(lora, "_revert_lora_impl", fake_revert)

    # ``match`` ensures the RuntimeError seen is the one raised below, not an
    # unrelated RuntimeError coming out of lora itself.
    with pytest.raises(RuntimeError, match="inference failed mid-step"):
        with lora.applied_lora(pipe, p, strength=1.0):
            # Apply must have happened, otherwise the revert check below is vacuous.
            assert pipe.applied == [(str(p), 1.0)]
            raise RuntimeError("inference failed mid-step")

    assert pipe.reverted == [True], "must still revert on exception"