Spaces:
Running on Zero
Running on Zero
feat(pipeline): add device autodetect with mps-safe vram limit
Browse files- ace_pipeline.py +40 -0
- tests/test_ace_pipeline.py +19 -0
ace_pipeline.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ACE-Step pipeline lifecycle: device autodetect, lazy load, cache mirror.
|
| 2 |
+
|
| 3 |
+
Mirrors z-image-studio's `models.py` pattern. M0 only implements device
|
| 4 |
+
detection — the pipeline class itself is filled in at M1.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def detect_device() -> str:
|
| 10 |
+
"""Returns 'cuda', 'mps', or 'cpu' in priority order."""
|
| 11 |
+
try:
|
| 12 |
+
import torch # local import: keep module import cheap for CI
|
| 13 |
+
except ImportError:
|
| 14 |
+
return "cpu"
|
| 15 |
+
|
| 16 |
+
if torch.cuda.is_available():
|
| 17 |
+
return "cuda"
|
| 18 |
+
# macOS: torch.backends.mps appeared in 2.0; guard for the rare absence
|
| 19 |
+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 20 |
+
return "mps"
|
| 21 |
+
return "cpu"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def vram_limit_for(device: str) -> int | None:
|
| 25 |
+
"""Returns a VRAM cap in bytes for CUDA, None otherwise.
|
| 26 |
+
|
| 27 |
+
`torch.mps` has no `mem_get_info` — calling DiffSynth-style free-VRAM
|
| 28 |
+
gates with a numeric limit would crash on MPS. Returning None lets the
|
| 29 |
+
pipeline short-circuit those checks.
|
| 30 |
+
"""
|
| 31 |
+
if device != "cuda":
|
| 32 |
+
return None
|
| 33 |
+
try:
|
| 34 |
+
import torch
|
| 35 |
+
|
| 36 |
+
free, _total = torch.cuda.mem_get_info()
|
| 37 |
+
# Leave 2 GiB headroom for activations
|
| 38 |
+
return max(0, free - 2 * 1024**3)
|
| 39 |
+
except Exception:
|
| 40 |
+
return None
|
tests/test_ace_pipeline.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""L1 tests for device autodetect — no torch needed if we mock importlib."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import ace_pipeline as ap
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def test_detect_device_returns_one_of_cuda_mps_cpu():
|
| 8 |
+
device = ap.detect_device()
|
| 9 |
+
assert device in {"cuda", "mps", "cpu"}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_vram_limit_for_mps_is_none():
|
| 13 |
+
"""MPS has no torch.mps.mem_get_info; return None so DiffSynth-style gates
|
| 14 |
+
short-circuit instead of crashing (z-image-studio paid this debug cycle)."""
|
| 15 |
+
assert ap.vram_limit_for("mps") is None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_vram_limit_for_cpu_is_none():
|
| 19 |
+
assert ap.vram_limit_for("cpu") is None
|