File size: 3,452 Bytes
dfa2ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96012ce
dfa2ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96012ce
dfa2ff6
 
 
 
 
 
 
 
 
 
 
c287b6a
 
 
 
dfa2ff6
 
 
 
 
 
99375d0
 
 
 
 
 
 
26dc3a4
 
 
 
 
99375d0
26dc3a4
 
 
dfa2ff6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""ACEStepStudioBackend — dispatch + ZeroGPU lifetime + duration estimator.

Off Spaces, @spaces.GPU is a no-op identity decorator (`spaces` may not be
installed locally). On Spaces, the HF runtime injects it at startup and
the decorator applies for real.
"""

from __future__ import annotations

import random
import time
from typing import Any

try:
    import spaces  # type: ignore[import-not-found]

    _HAS_SPACES = True
except ImportError:  # pragma: no cover - covered by manual local testing
    spaces = None
    _HAS_SPACES = False

import ace_pipeline as ap
import lora_stack


def _maybe_seed(seed: int | None) -> int:
    if seed and int(seed) > 0:
        return int(seed)
    return random.randint(1, 2_147_483_647)


def _duration_estimate(mode: str, params: dict[str, Any]) -> int:
    """ZeroGPU per-call duration cap, clamped [60, 180] s."""
    base = 60
    duration_s = int(params.get("duration_s", 30) or 30)
    if duration_s > 60:
        base = 90
    if duration_s > 120:
        base = 120
    if mode == "edit":
        base = max(base, 90)
    if mode == "extend":
        base = max(base, 120)
    return min(180, max(60, base))


class ACEStepStudioBackend:
    """Lazy backend singleton. Owns @spaces.GPU and pipeline lifecycle."""

    def dispatch(self, mode: str, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        params = dict(params)
        params["seed"] = _maybe_seed(params.get("seed"))
        t0 = time.time()
        pipe = ap.get_pipeline()
        lora_stack.apply_stack(pipe, params.get("loras", []))
        out_path = self._call_pipe_for_mode(pipe, mode, params)
        meta = {
            "mode": mode,
            "seed": params["seed"],
            "duration_s": params.get("duration_s"),
            "wall_seconds": round(time.time() - t0, 2),
            "estimated_duration_s": _duration_estimate(mode, params),
            "loras": [
                {"name": lora.get("name"), "scale": lora.get("scale"), "sha256": lora.get("sha256")}
                for lora in params.get("loras", [])
            ],
            # Echo the advanced + lm dicts back so the user can see which
            # knobs were active for a given output and lock-iterate from
            # there. The "seed" above is the resolved seed (never -1).
            "advanced": params.get("advanced", {}),
            "lm": params.get("lm", {}),
            "dcw": params.get("dcw", {}),
        }
        return out_path, meta

    def _call_pipe_for_mode(self, pipe, mode: str, params: dict[str, Any]) -> str:
        """Dispatch to the pipeline wrapper.

        ``pipe`` is the ``ACEStepStudio`` wrapper returned by
        ``ace_pipeline.get_pipeline()``. It exposes a single
        ``generate(params)`` method that handles the underlying
        AceStepHandler + LLMHandler + generate_music plumbing.

        All four song modes (``generate``, ``cover``, ``extend``, ``edit``)
        flow through ``pipe.generate(params)``. The pipeline wrapper
        switches its ``GenerationParams.task_type`` based on ``params["mode"]``
        — see ``ace_pipeline.ACEStepStudio.generate`` for the mapping. The
        ``lyrics`` mode is wired separately at M4.
        """
        if mode in ("generate", "cover", "extend", "edit"):
            params_with_mode = {**params, "mode": mode}
            return pipe.generate(params_with_mode)
        raise NotImplementedError(f"Mode {mode!r} is not wired yet")