ACE-Music-Studio / tests /test_ace_pipeline_lazy.py
techfreakworm's picture
feat(ui): add advanced controls accordion — inference steps, cfg, infer method, seed, lm cot, schedule, metadata
c287b6a unverified
"""L2 tests for the ACEStepStudio wrapper — mocks the heavy acestep imports."""
from __future__ import annotations
import sys
from unittest.mock import MagicMock
import pytest
import ace_pipeline as ap
def test_get_pipeline_returns_singleton(monkeypatch):
    """``get_pipeline`` hands back one shared ``ACEStepStudio`` instance."""
    # Reset the module-level cache so the first call below has to build a
    # studio; monkeypatch puts the original value back after the test.
    monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
    first, second = ap.get_pipeline(), ap.get_pipeline()
    assert first is second
    assert isinstance(first, ap.ACEStepStudio)
def test_studio_constructor_uses_detected_device(monkeypatch):
    """The constructor stores the value reported by ``detect_device``."""
    def _fake_detect():
        return "mps"

    monkeypatch.setattr(ap, "detect_device", _fake_detect)
    studio = ap.ACEStepStudio()
    assert studio.device == "mps"
    # Handlers are created lazily, so a freshly built studio is not loaded.
    assert studio.is_loaded is False
def test_studio_constructor_respects_env_overrides(monkeypatch):
    """``ACE_DIT_CONFIG`` / ``ACE_LM_MODEL`` replace the default model ids."""
    overrides = {"ACE_DIT_CONFIG": "custom-dit", "ACE_LM_MODEL": "custom-lm"}
    for var_name, var_value in overrides.items():
        monkeypatch.setenv(var_name, var_value)
    monkeypatch.setattr(ap, "detect_device", lambda: "cpu")

    studio = ap.ACEStepStudio()

    assert (studio._dit_config, studio._lm_model) == ("custom-dit", "custom-lm")
def test_studio_ensure_loaded_constructs_both_handlers(monkeypatch):
    """``_ensure_loaded`` constructs and initializes the DiT and LM handlers.

    The ``acestep.handler`` and ``acestep.llm_inference`` entries in
    ``sys.modules`` are replaced with mocks, so no real handler code runs.
    The test asserts the *default* ``config_path`` / ``lm_model_path``.
    Those values can be overridden through ``ACE_DIT_CONFIG`` /
    ``ACE_LM_MODEL``, so both variables are cleared first. Otherwise a
    machine that exports them would fail this test spuriously.
    """
    monkeypatch.delenv("ACE_DIT_CONFIG", raising=False)
    monkeypatch.delenv("ACE_LM_MODEL", raising=False)

    fake_dit = MagicMock()
    fake_lm = MagicMock()
    fake_dit_cls = MagicMock(name="AceStepHandler", return_value=fake_dit)
    fake_lm_cls = MagicMock(name="LLMHandler", return_value=fake_lm)

    handler_mod = MagicMock()
    handler_mod.AceStepHandler = fake_dit_cls
    llm_mod = MagicMock()
    llm_mod.LLMHandler = fake_lm_cls
    monkeypatch.setitem(sys.modules, "acestep.handler", handler_mod)
    monkeypatch.setitem(sys.modules, "acestep.llm_inference", llm_mod)
    monkeypatch.setattr(ap, "detect_device", lambda: "mps")

    studio = ap.ACEStepStudio()
    studio._ensure_loaded()

    fake_dit_cls.assert_called_once()
    fake_lm_cls.assert_called_once()
    fake_dit.initialize_service.assert_called_once()
    fake_lm.initialize.assert_called_once()

    dit_kwargs = fake_dit.initialize_service.call_args.kwargs
    lm_kwargs = fake_lm.initialize.call_args.kwargs
    # Both handlers must be initialized on the detected device.
    assert dit_kwargs["device"] == "mps"
    assert lm_kwargs["device"] == "mps"
    # With the override env vars cleared, the built-in defaults are used.
    assert dit_kwargs["config_path"] == "acestep-v15-xl-sft"
    assert lm_kwargs["lm_model_path"] == "acestep-5Hz-lm-0.6B"
def _install_fake_inference(monkeypatch, success=True, audios=None, error=None):
"""Plant a fake ``acestep.inference`` module and return the spies."""
if audios is None:
audios = [{"path": "/tmp/x.wav"}]
fake_result = MagicMock(success=success, audios=audios, error=error)
fake_generate = MagicMock(return_value=fake_result)
captured = {"gp": {}, "gc": {}}
def fake_gp(**kw):
captured["gp"] = kw
return kw
def fake_gc(**kw):
captured["gc"] = kw
return kw
fake_inference = MagicMock()
fake_inference.generate_music = fake_generate
fake_inference.GenerationParams = MagicMock(side_effect=fake_gp)
fake_inference.GenerationConfig = MagicMock(side_effect=fake_gc)
monkeypatch.setitem(sys.modules, "acestep.inference", fake_inference)
return fake_generate, captured
def test_studio_generate_builds_params_and_calls_generate_music(monkeypatch, tmp_path):
    """``generate`` maps a request onto ``GenerationParams`` and returns the audio path."""
    wav = tmp_path / "out.wav"
    # Tiny stub file so the path reported by the fake result exists on disk.
    wav.write_bytes(b"RIFF" + bytes(100))
    fake_generate, captured = _install_fake_inference(
        monkeypatch, audios=[{"path": str(wav)}]
    )

    studio = ap.ACEStepStudio()
    # Pre-set the handler slots with mocks (presumably lets generate skip
    # lazy loading — see ACEStepStudio._ensure_loaded).
    studio._dit = MagicMock(name="dit")
    studio._llm = MagicMock(name="llm")

    request = {
        "prompt": "psytrance",
        "lyrics": "[verse]",
        "duration_s": 30,
        "instrumental": False,
        "seed": 42,
        "loras": [],
        # New advanced contract: ``inference_steps`` + ``guidance_scale``
        # + ``infer_method`` replace the old ``steps`` + ``cfg`` keys.
        # See ace_pipeline.ACEStepStudio.generate for the full schema.
        "advanced": {
            "inference_steps": 32,
            "guidance_scale": 4.0,
            "infer_method": "ode",
            "bpm": 135,
        },
        "lm": {"thinking": False},
        "dcw": {},
    }
    result_path = studio.generate(request)

    assert result_path == str(wav)
    fake_generate.assert_called_once()
    expected_params = {
        "caption": "psytrance",
        "duration": 30,
        "seed": 42,
        "inference_steps": 32,
        "guidance_scale": 4.0,
        "infer_method": "ode",
        "bpm": 135,
    }
    for key, value in expected_params.items():
        assert captured["gp"][key] == value, key
def test_studio_generate_raises_on_failure(monkeypatch):
    """An unsuccessful ``generate_music`` result surfaces as ``RuntimeError``."""
    _install_fake_inference(monkeypatch, success=False, audios=[], error="OOM")
    studio = ap.ACEStepStudio()
    studio._dit, studio._llm = MagicMock(), MagicMock()

    request = dict(
        prompt="p",
        lyrics="",
        duration_s=5,
        instrumental=True,
        seed=1,
        advanced={},
        lm={},
        dcw={},
    )
    # The fake result's error text must appear in the raised message.
    with pytest.raises(RuntimeError, match="OOM"):
        studio.generate(request)
def test_studio_generate_uses_instrumental_marker_when_lyrics_empty(monkeypatch):
    """Empty lyrics on an instrumental request are sent as ``[Instrumental]``."""
    _, captured = _install_fake_inference(monkeypatch)
    studio = ap.ACEStepStudio()
    studio._dit = MagicMock()
    studio._llm = MagicMock()

    studio.generate(
        dict(
            prompt="drone",
            lyrics="",
            duration_s=5,
            instrumental=True,
            seed=1,
            advanced={},
            lm={},
            dcw={},
        )
    )

    sent = captured["gp"]
    # Instrumental + empty lyrics → ACE-Step convention is "[Instrumental]"
    assert sent["lyrics"] == "[Instrumental]"
    assert sent["instrumental"] is True