multimodalart HF Staff commited on
Commit
72c410e
·
verified ·
1 Parent(s): cba3206

install flash-attn + stable-audio-tools at startup (SKIP_CUDA_BUILD)

Browse files
Files changed (2) hide show
  1. app.py +55 -0
  2. requirements.txt +4 -2
app.py CHANGED
@@ -12,10 +12,65 @@ from __future__ import annotations
12
  import spaces # noqa: F401
13
 
14
  import os
 
 
15
  import tempfile
16
  import time
 
17
  from dataclasses import dataclass
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  import gradio as gr
20
  import torch
21
  import torchaudio
 
12
  import spaces # noqa: F401
13
 
14
  import os
15
+ import subprocess
16
+ import sys
17
  import tempfile
18
  import time
19
+ import types
20
  from dataclasses import dataclass
21
 
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Runtime install: flash-attn (then stable-audio-tools).
25
+ #
26
+ # ``stable-audio-tools`` declares flash-attn as a hard dep. flash-attn's
27
+ # setup.py does ``import torch``, but pip's isolated build env on Spaces has
28
+ # no torch — the build fails before pyproject.toml can declare it. We install
29
+ # flash-attn here with ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` so the wheel
30
+ # materialises in seconds, then install stable-audio-tools (which now sees
31
+ # flash-attn as satisfied). The missing CUDA extension is stubbed below so
32
+ # ``import flash_attn`` succeeds.
33
+ # ---------------------------------------------------------------------------
34
+
35
+
36
+ def _ensure_flash_attn() -> None:
37
+ if "flash_attn_2_cuda" not in sys.modules:
38
+ sys.modules["flash_attn_2_cuda"] = types.ModuleType("flash_attn_2_cuda")
39
+ try:
40
+ import flash_attn # noqa: F401
41
+ return
42
+ except ImportError:
43
+ pass
44
+ print("[startup] installing flash-attn (CUDA build skipped) …", flush=True)
45
+ env = dict(os.environ)
46
+ env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
47
+ subprocess.check_call(
48
+ [sys.executable, "-m", "pip", "install", "--quiet",
49
+ "--no-build-isolation", "flash-attn>=2.7.0"],
50
+ env=env,
51
+ )
52
+ import flash_attn # noqa: F401
53
+ print("[startup] flash-attn installed.", flush=True)
54
+
55
+
56
+ def _ensure_stable_audio_tools() -> None:
57
+ try:
58
+ import stable_audio_tools # noqa: F401
59
+ return
60
+ except ImportError:
61
+ pass
62
+ print("[startup] installing stable-audio-tools …", flush=True)
63
+ subprocess.check_call(
64
+ [sys.executable, "-m", "pip", "install", "--quiet", "stable-audio-tools"],
65
+ )
66
+ import stable_audio_tools # noqa: F401
67
+ print("[startup] stable-audio-tools installed.", flush=True)
68
+
69
+
70
+ _ensure_flash_attn()
71
+ _ensure_stable_audio_tools()
72
+
73
+
74
  import gradio as gr
75
  import torch
76
  import torchaudio
requirements.txt CHANGED
@@ -1,6 +1,8 @@
1
  # torch / gradio / spaces are preinstalled on ZeroGPU Spaces.
2
- stable-audio-tools
 
 
3
  einops
4
  # PyWavelets 1.7+ ships wheels built for NumPy 2.x; older versions throw a
5
- # dtype-size ABI error on the ZeroGPU image. Force a fresh wheel.
6
  PyWavelets>=1.7.0
 
1
  # torch / gradio / spaces are preinstalled on ZeroGPU Spaces.
2
+ # stable-audio-tools is installed at app startup (after flash-attn is
3
+ # pre-installed) — its setup.py pulls flash-attn which fails to build in pip's
4
+ # isolated env. See app.py.
5
  einops
6
  # PyWavelets 1.7+ ships wheels built for NumPy 2.x; older versions throw a
7
+ # dtype-size ABI error on the ZeroGPU image.
8
  PyWavelets>=1.7.0