Spaces:
Running on Zero
Running on Zero
feat(pipeline): lazy ace-step singleton with device-aware load
Browse files- ace_pipeline.py +31 -0
- tests/test_ace_pipeline_lazy.py +39 -0
ace_pipeline.py
CHANGED
|
@@ -6,6 +6,8 @@ detection — the pipeline class itself is filled in at M1.
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def detect_device() -> str:
|
| 11 |
"""Returns 'cuda', 'mps', or 'cpu' in priority order."""
|
|
@@ -39,3 +41,32 @@ def vram_limit_for(device: str) -> int | None:
|
|
| 39 |
return max(0, free - 2 * 1024**3)
|
| 40 |
except Exception:
|
| 41 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
|
| 12 |
def detect_device() -> str:
|
| 13 |
"""Returns 'cuda', 'mps', or 'cpu' in priority order."""
|
|
|
|
| 41 |
return max(0, free - 2 * 1024**3)
|
| 42 |
except Exception:
|
| 43 |
return None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
_PIPELINE = None # module-level lazy singleton
|
| 47 |
+
_DEFAULT_MODEL_ID = "ACE-Step/ACE-Step-v1.5-XL-SFT"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _load_pipeline(device: str, model_path: str):
|
| 51 |
+
"""Construct the ACE-Step pipeline. Heavy import is local so unit tests can mock."""
|
| 52 |
+
from ace_step import ACEStepPipeline # type: ignore[import-not-found]
|
| 53 |
+
|
| 54 |
+
# On Mac, the apple-silicon fork sets dtype + backend automatically.
|
| 55 |
+
# On CUDA we pass bf16 explicitly.
|
| 56 |
+
if device == "cuda":
|
| 57 |
+
pipe = ACEStepPipeline.from_pretrained(model_path, torch_dtype="bf16")
|
| 58 |
+
else:
|
| 59 |
+
pipe = ACEStepPipeline.from_pretrained(model_path)
|
| 60 |
+
|
| 61 |
+
pipe.to(device)
|
| 62 |
+
return pipe
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_pipeline():
|
| 66 |
+
"""Lazy-load the ACE-Step pipeline once per process."""
|
| 67 |
+
global _PIPELINE
|
| 68 |
+
if _PIPELINE is None:
|
| 69 |
+
device = detect_device()
|
| 70 |
+
model_path = os.environ.get("ACE_MODEL_PATH", _DEFAULT_MODEL_ID)
|
| 71 |
+
_PIPELINE = _load_pipeline(device, model_path)
|
| 72 |
+
return _PIPELINE
|
tests/test_ace_pipeline_lazy.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""L2 tests for pipeline lazy load — mock the heavy ACE-Step import."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from unittest.mock import MagicMock
|
| 6 |
+
|
| 7 |
+
import ace_pipeline as ap
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_get_pipeline_loads_lazily_first_call_only(monkeypatch):
|
| 11 |
+
fake_pipe = MagicMock(name="fake_ace_pipeline")
|
| 12 |
+
loader = MagicMock(return_value=fake_pipe)
|
| 13 |
+
monkeypatch.setattr(ap, "_load_pipeline", loader)
|
| 14 |
+
monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
|
| 15 |
+
|
| 16 |
+
p1 = ap.get_pipeline()
|
| 17 |
+
p2 = ap.get_pipeline()
|
| 18 |
+
|
| 19 |
+
assert p1 is fake_pipe
|
| 20 |
+
assert p2 is fake_pipe
|
| 21 |
+
assert loader.call_count == 1, "pipeline should load exactly once"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_get_pipeline_uses_detected_device(monkeypatch):
|
| 25 |
+
monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
|
| 26 |
+
monkeypatch.setattr(ap, "detect_device", lambda: "mps")
|
| 27 |
+
captured = {}
|
| 28 |
+
|
| 29 |
+
def fake_load(device, model_path):
|
| 30 |
+
captured["device"] = device
|
| 31 |
+
captured["model_path"] = model_path
|
| 32 |
+
return MagicMock()
|
| 33 |
+
|
| 34 |
+
monkeypatch.setattr(ap, "_load_pipeline", fake_load)
|
| 35 |
+
|
| 36 |
+
ap.get_pipeline()
|
| 37 |
+
|
| 38 |
+
assert captured["device"] == "mps"
|
| 39 |
+
assert captured["model_path"] is not None
|