File size: 4,412 Bytes
99375d0
dfa2ff6
 
 
 
 
26dc3a4
 
dfa2ff6
 
 
99375d0
 
dfa2ff6
 
 
99375d0
 
dfa2ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99375d0
 
 
 
 
dfa2ff6
 
 
99375d0
 
 
 
dfa2ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99375d0
 
 
96012ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26dc3a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""L2 tests for backend.dispatch — pipeline is mocked at the wrapper boundary."""

from __future__ import annotations

from unittest.mock import MagicMock

import pytest

import backend as be


def test_dispatch_generate_calls_pipeline_generate(monkeypatch, tmp_path):
    """Backend should call ``pipe.generate(params)`` and return its path.

    Both the pipeline factory and ``lora_stack.apply_stack`` are replaced with
    mocks, so only the dispatch -> ``pipe.generate`` path is exercised.
    """
    fake_out = tmp_path / "out.wav"
    fake_out.write_bytes(b"RIFF" + b"\0" * 1000)

    fake_pipe = MagicMock()
    fake_pipe.generate.return_value = str(fake_out)
    monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
    # Keep the real LoRA loader out of this test: against a MagicMock pipeline
    # it would only exercise stubs, and the call into it is asserted by
    # test_dispatch_applies_lora_stack.
    monkeypatch.setattr("lora_stack.apply_stack", MagicMock())

    b = be.ACEStepStudioBackend()
    out_path, meta = b.dispatch(
        mode="generate",
        params={
            "prompt": "psytrance",
            "lyrics": "[verse]",
            "duration_s": 10,
            "instrumental": False,
            "seed": 42,
            "loras": [],
            "advanced": {},
            "lm": {},
            "dcw": {},
        },
    )

    assert out_path == str(fake_out)
    assert meta["mode"] == "generate"
    assert meta["seed"] == 42
    fake_pipe.generate.assert_called_once()
    # The full params dict is forwarded to pipe.generate
    sent_params = fake_pipe.generate.call_args.args[0]
    assert sent_params["prompt"] == "psytrance"
    assert sent_params["seed"] == 42


def test_dispatch_random_seed_if_zero(monkeypatch, tmp_path):
    """A requested ``seed`` of 0 is resolved to a seed in ``[1, 2**31 - 1]``.

    The resolved value must be reported in ``meta`` and must be the same value
    forwarded to ``pipe.generate``.
    """
    out = tmp_path / "x.wav"
    out.write_bytes(b"RIFF")
    fake_pipe = MagicMock()
    fake_pipe.generate.return_value = str(out)
    monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: fake_pipe)
    # Isolate dispatch from the real LoRA loader, as the LoRA and
    # mode-forwarding tests in this module do.
    monkeypatch.setattr("lora_stack.apply_stack", MagicMock())

    b = be.ACEStepStudioBackend()
    _, meta = b.dispatch(
        mode="generate",
        params={
            "prompt": "p",
            "lyrics": "",
            "duration_s": 5,
            "instrumental": False,
            "seed": 0,
            "loras": [],
            "advanced": {},
            "lm": {},
            "dcw": {},
        },
    )

    assert 1 <= meta["seed"] <= 2_147_483_647
    # Fail with a clear assertion (instead of an AttributeError on a None
    # ``call_args``) if dispatch never reached the pipeline.
    fake_pipe.generate.assert_called_once()
    # The seed-resolved value is the one forwarded to the wrapper
    sent_params = fake_pipe.generate.call_args.args[0]
    assert sent_params["seed"] == meta["seed"]


def test_dispatch_applies_lora_stack(monkeypatch, tmp_path):
    """``dispatch`` hands the requested LoRA list to ``lora_stack.apply_stack``.

    The stub must be called exactly once, with the pipeline returned by
    ``ace_pipeline.get_pipeline`` and the ``loras`` list from the request.
    """
    wav = tmp_path / "x.wav"
    wav.write_bytes(b"RIFF")

    pipe = MagicMock()
    pipe.generate.return_value = str(wav)
    monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: pipe)

    apply_stack_stub = MagicMock()
    monkeypatch.setattr("lora_stack.apply_stack", apply_stack_stub)

    requested_loras = [
        {
            "name": "RapMachine",
            "scale": 0.85,
            "path": "/x.safetensors",
            "sha256": "a" * 64,
        }
    ]
    request = dict(
        prompt="p",
        lyrics="",
        duration_s=5,
        instrumental=False,
        seed=1,
        loras=requested_loras,
        advanced={},
        lm={},
        dcw={},
    )

    backend = be.ACEStepStudioBackend()
    backend.dispatch(mode="generate", params=request)

    apply_stack_stub.assert_called_once_with(pipe, requested_loras)


@pytest.mark.parametrize(
    "mode,extra",
    [
        ("cover", {"ref_audio": "/tmp/ref.wav", "audio_cover_strength": 0.9}),
        ("extend", {"seed_audio": "/tmp/seed.wav", "extra_duration_s": 60}),
        (
            "edit",
            {
                "source_audio": "/tmp/src.wav",
                "segment_start_s": 50.0,
                "segment_end_s": 90.0,
                "sub_mode": "repaint",
            },
        ),
    ],
)
def test_dispatch_forwards_mode_to_pipe_generate(monkeypatch, tmp_path, mode, extra):
    """Each non-generate mode reaches ``pipe.generate`` with its keys intact.

    The params passed to the (single) ``pipe.generate`` call must carry the
    dispatched ``mode`` and every mode-specific key with its original value.
    """
    wav = tmp_path / "x.wav"
    wav.write_bytes(b"RIFF")

    pipe = MagicMock()
    pipe.generate.return_value = str(wav)
    monkeypatch.setattr("ace_pipeline.get_pipeline", lambda: pipe)
    monkeypatch.setattr("lora_stack.apply_stack", MagicMock())

    common = dict(
        prompt="p",
        lyrics="",
        duration_s=10,
        instrumental=True,
        seed=42,
        loras=[],
        advanced={},
        lm={},
        dcw={},
    )
    be.ACEStepStudioBackend().dispatch(mode=mode, params={**common, **extra})

    pipe.generate.assert_called_once()
    forwarded = pipe.generate.call_args.args[0]
    assert forwarded["mode"] == mode
    # Every mode-specific key must arrive at the wrapper unchanged.
    assert {key: forwarded[key] for key in extra} == extra