File size: 14,370 Bytes
46d16df
 
99375d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46d16df
0ea6ca2
46d16df
 
eb3bcb4
99375d0
 
 
 
6ad959d
99375d0
 
 
eb3bcb4
46d16df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99375d0
 
 
46d16df
 
 
 
 
 
 
 
 
 
 
eb3bcb4
 
99375d0
 
eb3bcb4
99375d0
 
 
eb3bcb4
99375d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26dc3a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99375d0
 
 
 
 
 
 
 
 
 
 
26dc3a4
99375d0
 
 
26dc3a4
 
 
99375d0
 
 
 
26dc3a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c287b6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99375d0
26dc3a4
 
99375d0
 
26dc3a4
99375d0
c287b6a
 
 
 
99375d0
 
 
 
 
 
 
26dc3a4
 
 
 
 
 
c287b6a
 
 
 
99375d0
 
 
 
c287b6a
 
 
 
99375d0
 
 
 
 
 
 
 
 
6ad959d
 
 
 
 
 
 
 
 
 
 
 
99375d0
 
 
 
 
 
6ad959d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99375d0
 
 
 
 
 
 
 
 
 
 
eb3bcb4
 
99375d0
eb3bcb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
"""ACE-Step pipeline lifecycle: device autodetect, lazy load, cache mirror.

The installed ``acestep`` package (apple-silicon fork on Mac, upstream on
CUDA) does NOT expose a single ``ACEStepPipeline.from_pretrained`` entry
point. The real API is a split-handler pattern:

  from acestep.handler import AceStepHandler           # DiT side
  from acestep.llm_inference import LLMHandler         # 5Hz LM planner
  from acestep.inference import (
      GenerationParams, GenerationConfig, generate_music,
  )

  dit = AceStepHandler()
  dit.initialize_service(project_root=..., config_path="acestep-v15-xl-sft",
                         device="mps")
  lm = LLMHandler()
  lm.initialize(checkpoint_dir=..., lm_model_path="acestep-5Hz-lm-0.6B",
                backend="vllm",      # auto-routes to mlx on mps
                device="mps")
  params = GenerationParams(caption=..., lyrics=..., duration=..., seed=...)
  cfg = GenerationConfig(batch_size=1, audio_format="wav")
  result = generate_music(dit, lm, params, cfg)
  # result.audios[0]["path"] is the WAV file

To keep ``backend.py`` and ``modes.py`` clean, this module provides a
single ``ACEStepStudio`` wrapper that owns both handlers and exposes a
``generate(params: dict) -> str`` method returning the audio path.
``get_pipeline()`` returns the lazy singleton wrapper.

Checkpoints live under ``{project_root}/checkpoints/{config_path}/``.
On Mac with the apple-silicon fork, the fork auto-downloads from
HuggingFace if a checkpoint is missing, but in practice we pre-download
via ``hf download`` before the first inference call to avoid pytest
timeouts.
"""

from __future__ import annotations

import os
from pathlib import Path

# Directory containing this module; handed to the DiT handler as ``project_root``.
_REPO_ROOT = Path(__file__).resolve().parent
# Checkpoint root; handed to the LM handler as ``checkpoint_dir``.
_CHECKPOINTS_DIR = _REPO_ROOT / "checkpoints"
# Directory generated audio is written to (created on demand by ``generate``).
_OUTPUT_DIR = _REPO_ROOT / "output"

# Fallback model identifiers, used when neither a constructor argument nor the
# ACE_DIT_CONFIG / ACE_LM_MODEL environment variable supplies one.
_DEFAULT_DIT_CONFIG = "acestep-v15-xl-sft"
_DEFAULT_LM_MODEL = "acestep-5Hz-lm-0.6B"


def detect_device() -> str:
    """Pick the best available torch device.

    Preference order is ``'cuda'``, then ``'mps'``, then ``'cpu'``. When
    torch cannot be imported at all the answer is ``'cpu'``.
    """
    try:
        import torch  # deferred so that importing this module stays cheap (CI)
    except ImportError:
        return "cpu"

    if torch.cuda.is_available():
        return "cuda"

    # The mps backend only exists from torch 2.0 onwards, so look it up
    # defensively instead of assuming the attribute is there.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"


def vram_limit_for(device: str) -> int | None:
    """Compute a VRAM budget in bytes for CUDA; ``None`` for everything else.

    ``torch.mps`` has no ``mem_get_info`` — calling DiffSynth-style
    free-VRAM gates with a numeric limit would crash on MPS, so non-CUDA
    devices get ``None`` and the pipeline can short-circuit those checks.
    Any failure while querying CUDA also yields ``None`` (best effort).
    """
    if device == "cuda":
        try:
            import torch

            free_bytes, _ = torch.cuda.mem_get_info()
            headroom = 2 * 1024**3  # keep 2 GiB back for activations
            return max(0, free_bytes - headroom)
        except Exception:
            return None
    return None


class ACEStepStudio:
    """Wrapper around the apple-silicon fork's split-handler API.

    Owns one ``AceStepHandler`` (DiT) and one ``LLMHandler`` (5Hz LM
    planner). Both are lazy-loaded on the first ``generate(...)`` call
    that passes request validation.
    """

    def __init__(
        self,
        dit_config: str | None = None,
        lm_model: str | None = None,
        device: str | None = None,
    ) -> None:
        """Record configuration only; no model is loaded here.

        Each setting is resolved as: explicit (truthy) argument, else the
        ``ACE_DIT_CONFIG`` / ``ACE_LM_MODEL`` environment variable, else the
        module default. ``device`` falls back to ``detect_device()``.
        """
        self._dit = None
        self._llm = None
        self._dit_config = dit_config or os.environ.get("ACE_DIT_CONFIG", _DEFAULT_DIT_CONFIG)
        self._lm_model = lm_model or os.environ.get("ACE_LM_MODEL", _DEFAULT_LM_MODEL)
        self._device = device or detect_device()

    @property
    def device(self) -> str:
        """Device string the handlers are (or will be) initialized on."""
        return self._device

    @property
    def is_loaded(self) -> bool:
        """True once both the DiT handler and the LM handler are set."""
        return self._dit is not None and self._llm is not None

    def _ensure_loaded(self) -> None:
        """First-call lazy load of both handlers. Heavy imports stay local."""
        if self.is_loaded:
            return

        from acestep.handler import AceStepHandler
        from acestep.llm_inference import LLMHandler

        dit = AceStepHandler()
        dit.initialize_service(
            project_root=str(_REPO_ROOT),
            config_path=self._dit_config,
            device=self._device,
        )

        llm = LLMHandler()
        llm.initialize(
            checkpoint_dir=str(_CHECKPOINTS_DIR),
            lm_model_path=self._lm_model,
            backend="vllm",  # fork auto-routes to mlx on mps + mlx-lm installed
            device=self._device,
        )

        # Publish both handlers only after both initialized, so ``is_loaded``
        # never reports a half-loaded wrapper.
        self._dit = dit
        self._llm = llm

    @staticmethod
    def _resolve_mode_inputs(
        params: dict, mode: str
    ) -> tuple[str, str | None, str | None, float, float, float]:
        """Map ``mode`` to the ACE-Step task type and its audio inputs.

        Returns ``(task_type, ref_audio, src_audio, audio_cover_strength,
        repainting_start, repainting_end)``. Pure dict inspection: nothing
        from ``acestep`` is needed, so this can run before any model load.

        Raises ``ValueError`` for an unknown mode.
        """
        ref_audio: str | None = None
        src_audio: str | None = None
        audio_cover_strength = 0.0
        repainting_start = 0.0
        repainting_end = -1.0

        if mode == "generate":
            task_type = "text2music"
        elif mode == "cover":
            task_type = "cover"
            ref_audio = params.get("ref_audio")
            audio_cover_strength = float(params.get("audio_cover_strength", 0.93))
        elif mode == "extend":
            task_type = "repaint"
            src_audio = params.get("seed_audio")
            # Sentinel: -1 / -1 means "append after the seed audio's end".
            # ACE-Step's repaint path interprets these bounds against the
            # src_audio duration; the actual semantics need verifying once
            # we run a full pass on real hardware (M3 GPU smoke).
            repainting_start = -1.0
            repainting_end = -1.0
        elif mode == "edit":
            task_type = "repaint"
            src_audio = params.get("source_audio")
            repainting_start = float(params.get("segment_start_s", 0.0))
            repainting_end = float(params.get("segment_end_s", 30.0))
            # flow_edit sub-mode: lower audio_cover_strength to allow style
            # drift while still using the repaint task type. The extra
            # flow_* fields in our internal params dict are kept around for
            # future use but not forwarded to GenerationParams (no native
            # support in the installed dataclass).
            if params.get("sub_mode") == "flow_edit":
                audio_cover_strength = 0.3
        else:
            raise ValueError(f"Unknown mode: {mode!r}")

        return (
            task_type,
            ref_audio,
            src_audio,
            audio_cover_strength,
            repainting_start,
            repainting_end,
        )

    def generate(self, params: dict) -> str:
        """Run a single song generation across all four modes.

        ``params`` is the dict produced by the mode handlers in ``modes.py``.
        The ``params["mode"]`` key (``generate`` | ``cover`` | ``extend`` |
        ``edit``) selects the ACE-Step ``task_type`` and which audio inputs
        get wired through to ``GenerationParams``:

        - ``generate``: ``task_type="text2music"``
        - ``cover``:    ``task_type="cover"`` + ``reference_audio`` +
          ``audio_cover_strength``
        - ``extend``:   ``task_type="repaint"`` + ``src_audio`` set to the
          seed, with ``repainting_start=-1`` / ``repainting_end=-1`` as a
          sentinel meaning "paint after the end of the seed". The actual
          mask shaping ultimately lives inside ACE-Step's repaint path.
        - ``edit``:     ``task_type="repaint"`` + ``src_audio`` + explicit
          ``[segment_start_s, segment_end_s]`` segment bounds.

        Flow-edit (``sub_mode="flow_edit"``) is implemented as a repaint
        pass: the installed ACE-Step ``GenerationParams`` dataclass has no
        native ``flow_edit_*`` fields, so the extra flow-edit knobs carried
        in the internal params dict are ignored at the ``GenerationParams``
        instantiation level and will need wiring once upstream grows them.

        The mode is validated before the handlers are loaded, so a
        malformed request fails fast instead of paying for a model load.

        Returns the path to the produced audio file.

        Raises ``ValueError`` for an unknown mode and ``RuntimeError`` when
        ACE-Step reports failure or returns no usable audio.
        """
        advanced = params.get("advanced", {}) or {}
        lm_opts = params.get("lm", {}) or {}
        mode = params.get("mode", "generate")

        # Mode-specific task_type + audio inputs. Resolved up front: this
        # is cheap dict work and rejects bad requests before the (expensive)
        # lazy load below.
        (
            task_type,
            ref_audio,
            src_audio,
            audio_cover_strength,
            repainting_start,
            repainting_end,
        ) = self._resolve_mode_inputs(params, mode)

        self._ensure_loaded()

        from acestep.inference import (
            GenerationConfig,
            GenerationParams,
            generate_music,
        )

        # Map our internal dict to ACE-Step's GenerationParams.
        # Lyrics "[Instrumental]" is the ACE-Step convention for instrumental.
        lyrics = params.get("lyrics", "") or params.get("extension_lyrics", "") or ""
        if mode == "edit":
            lyrics = params.get("target_lyrics", "") or lyrics
        instrumental = bool(params.get("instrumental", False))
        if instrumental and not lyrics:
            lyrics = "[Instrumental]"

        # Caption can come from the per-mode handlers under different keys.
        caption = (
            params.get("prompt") or params.get("extra_prompt") or params.get("flow_source_caption") or ""
        )
        duration_s = int(params.get("duration_s") or params.get("extra_duration_s") or 30)

        # ``advanced``/``lm`` dicts are sent by app.py's
        # ``_build_advanced_params``. Key changes from the prior contract:
        # - ``inference_steps`` (was ``steps``, defaulted to 8 which made the
        #   XL SFT model behave too turbo-ish; new default 27).
        # - ``guidance_scale`` (was ``cfg``, default 7.0 for stronger prompt
        #   adherence).
        # - ``infer_method`` (new — ``"ode"`` deterministic / ``"sde"``
        #   stochastic; the user can now flip to ``sde`` to actually get
        #   different output each click even with the same seed).
        # - ``use_adg`` (new — Adaptive Dual Guidance; experimental).
        # - ``thinking`` (5Hz LM CoT — default flips to True so the LM can
        #   reason about caption + metadata, which is the actual source of
        #   the "no matter what prompt the style barely changes" symptom).
        # - ``use_cot_metas`` / ``use_cot_caption`` / ``use_cot_language``
        #   keys renamed from ``cot_*`` for consistency with the dataclass.
        gen_params = GenerationParams(
            task_type=task_type,
            caption=caption,
            lyrics=lyrics,
            instrumental=instrumental,
            duration=duration_s,
            seed=int(params.get("seed", -1)),
            inference_steps=int(advanced.get("inference_steps", 27)),
            guidance_scale=float(advanced.get("guidance_scale", 7.0)),
            infer_method=str(advanced.get("infer_method", "ode")),
            use_adg=bool(advanced.get("use_adg", False)),
            shift=float(advanced.get("shift", 1.0)),
            bpm=advanced.get("bpm"),
            keyscale=advanced.get("keyscale", ""),
            timesignature=advanced.get("timesignature", ""),
            vocal_language=advanced.get("vocal_language", "unknown"),
            cfg_interval_start=float(advanced.get("cfg_interval_start", 0.0)),
            cfg_interval_end=float(advanced.get("cfg_interval_end", 1.0)),
            # Mode-specific audio inputs + repaint bounds
            reference_audio=ref_audio,
            src_audio=src_audio,
            audio_cover_strength=audio_cover_strength,
            repainting_start=repainting_start,
            repainting_end=repainting_end,
            # 5Hz language model knobs — defaults flipped to True so the
            # LM actually reasons about each prompt instead of returning
            # blank captions / metadata back to the DiT.
            thinking=bool(lm_opts.get("thinking", True)),
            lm_temperature=float(lm_opts.get("temperature", 0.85)),
            lm_cfg_scale=float(lm_opts.get("cfg", 2.0)),
            lm_top_k=int(lm_opts.get("top_k", 0)),
            lm_top_p=float(lm_opts.get("top_p", 0.9)),
            lm_negative_prompt=lm_opts.get("negative_prompt", "NO USER INPUT"),
            use_cot_metas=bool(lm_opts.get("use_cot_metas", True)),
            use_cot_caption=bool(lm_opts.get("use_cot_caption", True)),
            use_cot_language=bool(lm_opts.get("use_cot_language", True)),
        )

        # NOTE(review): when ``params`` carries no "seed" key the default here
        # is 1, while GenerationParams above receives -1. Left as-is because
        # which of the two ACE-Step honours is not visible from this module —
        # confirm the intended default and unify.
        gen_config = GenerationConfig(
            batch_size=1,
            audio_format=advanced.get("audio_format", "wav"),
            use_random_seed=False,
            seeds=[int(params.get("seed", 1))],
        )

        # generate_music only writes a file when save_dir is provided; otherwise
        # result.audios[i]["path"] is empty and ["tensor"] holds the raw audio.
        # Pass an explicit output dir so the path is always usable.
        _OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

        result = generate_music(
            self._dit,
            self._llm,
            gen_params,
            gen_config,
            save_dir=str(_OUTPUT_DIR),
        )

        if not result.success:
            raise RuntimeError(f"ACE-Step generation failed: {result.error}")
        if not result.audios:
            raise RuntimeError("ACE-Step returned no audio outputs")

        audio = result.audios[0]
        path = audio.get("path") or ""
        if not path:
            # generate_music returned an empty path despite save_dir being passed.
            # Fall back to writing the in-memory tensor so callers always get a
            # valid file path (Gradio cannot serve an empty path).
            import soundfile as sf

            tensor = audio.get("tensor")
            if tensor is None:
                raise RuntimeError("ACE-Step returned neither an audio path nor a tensor")
            sample_rate = int(audio.get("sample_rate", 48000))
            audio_format = advanced.get("audio_format", "wav")
            fallback = _OUTPUT_DIR / f"{audio.get('key', 'fallback')}.{audio_format}"
            data = tensor.detach().cpu().numpy()
            # soundfile expects (frames, channels); acestep tensors are (channels, frames)
            if data.ndim == 2 and data.shape[0] in (1, 2):
                data = data.T
            sf.write(str(fallback), data, sample_rate)
            path = str(fallback)

        return path


# Process-wide wrapper instance; stays None until get_pipeline() is first called.
_PIPELINE: ACEStepStudio | None = None


def get_pipeline() -> ACEStepStudio:
    """Return the shared ``ACEStepStudio``, creating it on first use.

    Constructing the wrapper is cheap: the DiT and LM handlers are only
    loaded on the first ``generate(...)`` call, not here.
    """
    global _PIPELINE
    studio = _PIPELINE
    if studio is None:
        studio = _PIPELINE = ACEStepStudio()
    return studio