techfreakworm committed on
Commit
46d16df
·
unverified ·
1 Parent(s): 8c574cb

feat(pipeline): add device autodetect with mps-safe vram limit

Browse files
Files changed (2) hide show
  1. ace_pipeline.py +40 -0
  2. tests/test_ace_pipeline.py +19 -0
ace_pipeline.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ACE-Step pipeline lifecycle: device autodetect, lazy load, cache mirror.
2
+
3
+ Mirrors z-image-studio's `models.py` pattern. M0 only implements device
4
+ detection — the pipeline class itself is filled in at M1.
5
+ """
6
+ from __future__ import annotations
7
+
8
+
9
def detect_device() -> str:
    """Return the best available torch backend name.

    Backends are probed in priority order: 'cuda', then 'mps', then 'cpu'.
    'cpu' is also returned when torch cannot be imported, so callers always
    receive a usable device string.

    Returns:
        One of 'cuda', 'mps' or 'cpu'.
    """
    try:
        import torch  # local import: keep module import cheap for CI
    except (ImportError, OSError):
        # ImportError: torch is not installed. OSError: torch is installed but
        # its native libraries failed to load. Either way only CPU is usable.
        return "cpu"

    if torch.cuda.is_available():
        return "cuda"
    # macOS: some torch builds have no `torch.backends.mps` attribute at all,
    # so look it up defensively before asking whether MPS is available.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
22
+
23
+
24
def vram_limit_for(device: str) -> int | None:
    """Return a VRAM cap in bytes for a CUDA device, or None when no cap applies.

    `torch.mps` has no `mem_get_info` — calling DiffSynth-style free-VRAM
    gates with a numeric limit would crash on MPS. Returning None lets the
    pipeline short-circuit those checks.

    Args:
        device: Device string such as 'cuda', 'cuda:1', 'mps' or 'cpu'.

    Returns:
        Free memory on the CUDA device minus 2 GiB of headroom, clamped at 0.
        None for non-CUDA devices, or when free memory cannot be queried
        (torch missing, CUDA unusable, malformed device string).
    """
    if device != "cuda" and not device.startswith("cuda:"):
        return None

    headroom_bytes = 2 * 1024**3  # leave 2 GiB headroom for activations
    try:
        import torch

        # Plain 'cuda' means the current device; 'cuda:N' names one explicitly.
        target = None if device == "cuda" else torch.device(device)
        free, _total = torch.cuda.mem_get_info(target)
    except Exception:
        # Best-effort probe: an unknown limit is reported as None so callers
        # skip their VRAM gates, but the reason is kept for debugging.
        import logging

        logging.getLogger(__name__).debug(
            "Could not query free VRAM for %r", device, exc_info=True
        )
        return None
    return max(0, free - headroom_bytes)
tests/test_ace_pipeline.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """L1 tests for device autodetect — torch is optional: these paths fall back without it."""
2
+ from __future__ import annotations
3
+
4
+ import ace_pipeline as ap
5
+
6
+
7
def test_detect_device_returns_one_of_cuda_mps_cpu():
    """Autodetect must resolve to one of the three known backend names."""
    known_backends = {"cuda", "mps", "cpu"}
    assert ap.detect_device() in known_backends
10
+
11
+
12
def test_vram_limit_for_mps_is_none():
    """No VRAM cap is reported for MPS.

    MPS has no torch.mps.mem_get_info; a None limit lets DiffSynth-style gates
    short-circuit instead of crashing (z-image-studio paid this debug cycle).
    """
    mps_limit = ap.vram_limit_for("mps")
    assert mps_limit is None
16
+
17
+
18
def test_vram_limit_for_cpu_is_none():
    """No VRAM cap is reported for the CPU backend."""
    cpu_limit = ap.vram_limit_for("cpu")
    assert cpu_limit is None