techfreakworm commited on
Commit
eb3bcb4
·
unverified ·
1 Parent(s): 2e5af7a

feat(pipeline): lazy ace-step singleton with device-aware load

Browse files
Files changed (2) hide show
  1. ace_pipeline.py +31 -0
  2. tests/test_ace_pipeline_lazy.py +39 -0
ace_pipeline.py CHANGED
@@ -6,6 +6,8 @@ detection — the pipeline class itself is filled in at M1.
6
 
7
  from __future__ import annotations
8
 
 
 
9
 
10
  def detect_device() -> str:
11
  """Returns 'cuda', 'mps', or 'cpu' in priority order."""
@@ -39,3 +41,32 @@ def vram_limit_for(device: str) -> int | None:
39
  return max(0, free - 2 * 1024**3)
40
  except Exception:
41
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from __future__ import annotations
8
 
9
+ import os
10
+
11
 
12
  def detect_device() -> str:
13
  """Returns 'cuda', 'mps', or 'cpu' in priority order."""
 
41
  return max(0, free - 2 * 1024**3)
42
  except Exception:
43
  return None
44
+
45
+
46
+ _PIPELINE = None # module-level lazy singleton
47
+ _DEFAULT_MODEL_ID = "ACE-Step/ACE-Step-v1.5-XL-SFT"
48
+
49
+
50
+ def _load_pipeline(device: str, model_path: str):
51
+ """Construct the ACE-Step pipeline. Heavy import is local so unit tests can mock."""
52
+ from ace_step import ACEStepPipeline # type: ignore[import-not-found]
53
+
54
+ # On Mac, the apple-silicon fork sets dtype + backend automatically.
55
+ # On CUDA we pass bf16 explicitly.
56
+ if device == "cuda":
57
+ pipe = ACEStepPipeline.from_pretrained(model_path, torch_dtype="bf16")
58
+ else:
59
+ pipe = ACEStepPipeline.from_pretrained(model_path)
60
+
61
+ pipe.to(device)
62
+ return pipe
63
+
64
+
65
+ def get_pipeline():
66
+ """Lazy-load the ACE-Step pipeline once per process."""
67
+ global _PIPELINE
68
+ if _PIPELINE is None:
69
+ device = detect_device()
70
+ model_path = os.environ.get("ACE_MODEL_PATH", _DEFAULT_MODEL_ID)
71
+ _PIPELINE = _load_pipeline(device, model_path)
72
+ return _PIPELINE
tests/test_ace_pipeline_lazy.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """L2 tests for pipeline lazy load — mock the heavy ACE-Step import."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from unittest.mock import MagicMock
6
+
7
+ import ace_pipeline as ap
8
+
9
+
10
+ def test_get_pipeline_loads_lazily_first_call_only(monkeypatch):
11
+ fake_pipe = MagicMock(name="fake_ace_pipeline")
12
+ loader = MagicMock(return_value=fake_pipe)
13
+ monkeypatch.setattr(ap, "_load_pipeline", loader)
14
+ monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
15
+
16
+ p1 = ap.get_pipeline()
17
+ p2 = ap.get_pipeline()
18
+
19
+ assert p1 is fake_pipe
20
+ assert p2 is fake_pipe
21
+ assert loader.call_count == 1, "pipeline should load exactly once"
22
+
23
+
24
+ def test_get_pipeline_uses_detected_device(monkeypatch):
25
+ monkeypatch.setattr(ap, "_PIPELINE", None, raising=False)
26
+ monkeypatch.setattr(ap, "detect_device", lambda: "mps")
27
+ captured = {}
28
+
29
+ def fake_load(device, model_path):
30
+ captured["device"] = device
31
+ captured["model_path"] = model_path
32
+ return MagicMock()
33
+
34
+ monkeypatch.setattr(ap, "_load_pipeline", fake_load)
35
+
36
+ ap.get_pipeline()
37
+
38
+ assert captured["device"] == "mps"
39
+ assert captured["model_path"] is not None