multimodalart HF Staff commited on
Commit
5728513
·
1 Parent(s): b134d23

Fix Blackwell support: pin numpy<2, install stable-audio-tools --no-deps, fall back to SDPA

Browse files
Files changed (2) hide show
  1. app.py +7 -39
  2. requirements.txt +23 -2
app.py CHANGED
@@ -8,7 +8,6 @@ duration); steps / CFG / sampler / seed live in an Advanced accordion.
8
 
9
  from __future__ import annotations
10
 
11
- # spaces must be imported before any CUDA-touching module.
12
  import spaces # noqa: F401
13
 
14
  import os
@@ -19,55 +18,24 @@ import time
19
  import types
20
  from dataclasses import dataclass
21
 
22
-
23
- # ---------------------------------------------------------------------------
24
- # Runtime install: flash-attn (then stable-audio-tools).
25
- #
26
- # ``stable-audio-tools`` declares flash-attn as a hard dep. flash-attn's
27
- # setup.py does ``import torch``, but pip's isolated build env on Spaces has
28
- # no torch — the build fails before pyproject.toml can declare it. We install
29
- # flash-attn here with ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` so the wheel
30
- # materialises in seconds, then install stable-audio-tools (which now sees
31
- # flash-attn as satisfied). The missing CUDA extension is stubbed below so
32
- # ``import flash_attn`` succeeds.
33
- # ---------------------------------------------------------------------------
34
-
35
-
36
- def _ensure_flash_attn() -> None:
37
- if "flash_attn_2_cuda" not in sys.modules:
38
- sys.modules["flash_attn_2_cuda"] = types.ModuleType("flash_attn_2_cuda")
39
- try:
40
- import flash_attn # noqa: F401
41
- return
42
- except ImportError:
43
- pass
44
- print("[startup] installing flash-attn (CUDA build skipped) …", flush=True)
45
- env = dict(os.environ)
46
- env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
47
- subprocess.check_call(
48
- [sys.executable, "-m", "pip", "install", "--quiet",
49
- "--no-build-isolation", "flash-attn>=2.7.0"],
50
- env=env,
51
- )
52
- import flash_attn # noqa: F401
53
- print("[startup] flash-attn installed.", flush=True)
54
-
55
-
56
  def _ensure_stable_audio_tools() -> None:
57
  try:
58
  import stable_audio_tools # noqa: F401
59
  return
60
  except ImportError:
61
  pass
62
- print("[startup] installing stable-audio-tools …", flush=True)
 
 
 
 
63
  subprocess.check_call(
64
- [sys.executable, "-m", "pip", "install", "--quiet", "stable-audio-tools"],
 
65
  )
66
  import stable_audio_tools # noqa: F401
67
  print("[startup] stable-audio-tools installed.", flush=True)
68
 
69
-
70
- _ensure_flash_attn()
71
  _ensure_stable_audio_tools()
72
 
73
 
 
8
 
9
  from __future__ import annotations
10
 
 
11
  import spaces # noqa: F401
12
 
13
  import os
 
18
  import types
19
  from dataclasses import dataclass
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def _ensure_stable_audio_tools() -> None:
22
  try:
23
  import stable_audio_tools # noqa: F401
24
  return
25
  except ImportError:
26
  pass
27
+ # stable-audio-tools 0.0.20 strict-pins torch==2.7.1 / torchaudio==2.7.1,
28
+ # which lack sm_120 (Blackwell) kernels. Install with --no-deps; the
29
+ # transitive deps are listed in requirements.txt and resolved against the
30
+ # sm_120-capable torch at build time.
31
+ print("[startup] installing stable-audio-tools (--no-deps) …", flush=True)
32
  subprocess.check_call(
33
+ [sys.executable, "-m", "pip", "install", "--quiet", "--no-deps",
34
+ "stable-audio-tools"],
35
  )
36
  import stable_audio_tools # noqa: F401
37
  print("[startup] stable-audio-tools installed.", flush=True)
38
 
 
 
39
  _ensure_stable_audio_tools()
40
 
41
 
requirements.txt CHANGED
@@ -1,6 +1,27 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu132
2
  einops
3
  soundfile
4
- PyWavelets>=1.7.0
 
5
  torch
6
  torchaudio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu132
2
  einops
3
  soundfile
4
+ numpy<2
5
+ pytorch_lightning
6
  torch
7
  torchaudio
8
+ torchvision
9
+ # stable-audio-tools transitive deps (sat itself strict-pins torch==2.7.1, so
10
+ # we install it with --no-deps in app.py and resolve its deps here against
11
+ # our sm_120-capable torch).
12
+ alias-free-torch
13
+ dill
14
+ einops-exts
15
+ huggingface_hub
16
+ importlib-resources
17
+ nnAudio
18
+ PyWavelets
19
+ safetensors
20
+ scipy
21
+ sentencepiece
22
+ soxr
23
+ torchsde
24
+ tqdm
25
+ transformers
26
+ v-diffusion-pytorch
27
+ vector-quantize-pytorch