Spaces:
Running on Zero
Running on Zero
feat(ui): add advanced controls accordion — inference steps, cfg, infer method, seed, lm cot, schedule, metadata
c287b6a unverified | """ACEStepStudioBackend — dispatch + ZeroGPU lifetime + duration estimator. | |
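Off Spaces, `spaces` may not be installed locally. In that case the import
below falls back to ``spaces = None`` with ``_HAS_SPACES = False``, so
``@spaces.GPU`` must not be applied unless ``_HAS_SPACES`` is true. When the
package is installed but not running on ZeroGPU, ``@spaces.GPU`` is a no-op
identity decorator. On Spaces, the HF runtime injects it at startup and the
decorator applies for real.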
| """ | |
| from __future__ import annotations | |
| import random | |
| import time | |
| from typing import Any | |
| try: | |
| import spaces # type: ignore[import-not-found] | |
| _HAS_SPACES = True | |
| except ImportError: # pragma: no cover - covered by manual local testing | |
| spaces = None | |
| _HAS_SPACES = False | |
| import ace_pipeline as ap | |
| import lora_stack | |
| def _maybe_seed(seed: int | None) -> int: | |
| if seed and int(seed) > 0: | |
| return int(seed) | |
| return random.randint(1, 2_147_483_647) | |
| def _duration_estimate(mode: str, params: dict[str, Any]) -> int: | |
| """ZeroGPU per-call duration cap, clamped [60, 180] s.""" | |
| base = 60 | |
| duration_s = int(params.get("duration_s", 30) or 30) | |
| if duration_s > 60: | |
| base = 90 | |
| if duration_s > 120: | |
| base = 120 | |
| if mode == "edit": | |
| base = max(base, 90) | |
| if mode == "extend": | |
| base = max(base, 120) | |
| return min(180, max(60, base)) | |
class ACEStepStudioBackend:
    """Backend that turns one UI request into one pipeline call.

    Each ``dispatch`` resolves the seed, fetches the pipeline via
    ``ace_pipeline.get_pipeline()``, applies the requested LoRA stack and
    routes the request by mode. Presumably ``get_pipeline()`` caches a
    single lazily-built instance; confirm in ``ace_pipeline``.

    NOTE(review): nothing in this class applies ``@spaces.GPU``. Confirm
    that the caller wraps ``dispatch`` (e.g. using ``_duration_estimate``
    for the duration budget) so ZeroGPU actually allocates a device.
    """

    # Modes that all route through ``pipe.generate(params)``.
    _SONG_MODES = ("generate", "cover", "extend", "edit")

    def dispatch(self, mode: str, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Run a single request and return ``(out_path, meta)``.

        Args:
            mode: Requested mode; see ``_call_pipe_for_mode`` for which
                modes are wired.
            params: Request parameters. The dict is shallow-copied first,
                so the caller's dict is not mutated. ``seed`` is replaced
                by the resolved seed. ``loras`` is optional, and an explicit
                ``None`` is treated as an empty stack.

        Returns:
            The path returned by the pipeline, plus a metadata dict echoing
            the mode, resolved seed, timing, estimated GPU budget, LoRAs and
            the ``advanced`` / ``lm`` / ``dcw`` settings.

        Raises:
            NotImplementedError: If ``mode`` is not wired yet.
        """
        params = dict(params)
        params["seed"] = _maybe_seed(params.get("seed"))
        # ``or []`` also covers an explicit ``loras=None``. Without it, the
        # metadata comprehension below would try to iterate ``None``.
        loras = params.get("loras") or []

        # Monotonic clock: elapsed time must not jump with system clock changes.
        t0 = time.monotonic()
        pipe = ap.get_pipeline()
        lora_stack.apply_stack(pipe, loras)
        out_path = self._call_pipe_for_mode(pipe, mode, params)
        elapsed = time.monotonic() - t0

        meta = {
            "mode": mode,
            "seed": params["seed"],
            "duration_s": params.get("duration_s"),
            "wall_seconds": round(elapsed, 2),
            "estimated_duration_s": _duration_estimate(mode, params),
            "loras": [
                {"name": lora.get("name"), "scale": lora.get("scale"), "sha256": lora.get("sha256")}
                for lora in loras
            ],
            # Echo the advanced + lm dicts back so the user can see which
            # knobs were active for a given output and lock-iterate from
            # there. The "seed" above is the resolved seed (never -1).
            "advanced": params.get("advanced", {}),
            "lm": params.get("lm", {}),
            "dcw": params.get("dcw", {}),
        }
        return out_path, meta

    def _call_pipe_for_mode(self, pipe, mode: str, params: dict[str, Any]) -> str:
        """Dispatch to the pipeline wrapper and return its output path.

        ``pipe`` is the ``ACEStepStudio`` wrapper returned by
        ``ace_pipeline.get_pipeline()``. It exposes a single
        ``generate(params)`` method that handles the underlying
        AceStepHandler + LLMHandler + generate_music plumbing.

        All four song modes (``generate``, ``cover``, ``extend``, ``edit``)
        flow through ``pipe.generate(params)``. The pipeline receives a copy
        of ``params`` with ``"mode"`` set, and its wrapper switches its
        ``GenerationParams.task_type`` based on ``params["mode"]``. See
        ``ace_pipeline.ACEStepStudio.generate`` for the mapping. The
        ``lyrics`` mode is wired separately at M4.

        Raises:
            NotImplementedError: For any mode not in ``_SONG_MODES``.
        """
        if mode in self._SONG_MODES:
            params_with_mode = {**params, "mode": mode}
            return pipe.generate(params_with_mode)
        raise NotImplementedError(f"Mode {mode!r} is not wired yet")