zhu-han committed on
Commit
bd9d198
·
verified ·
1 Parent(s): 6422d69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -17
app.py CHANGED
@@ -9,6 +9,12 @@ import os
9
  import tempfile
10
  from typing import Any, Dict
11
 
 
 
 
 
 
 
12
  import torch
13
  import torchaudio
14
 
@@ -16,7 +22,7 @@ from omnivoice import OmniVoice, OmniVoiceGenerationConfig
16
  from omnivoice.cli.demo import build_demo
17
 
18
  logger = logging.getLogger(__name__)
19
- logging.basicConfig(level=logging.INFO)
20
 
21
  # ---------------------------------------------------------------------------
22
  # Hardware detection
@@ -29,15 +35,22 @@ logger.info(f"Using device: {DEVICE}")
29
  # ---------------------------------------------------------------------------
30
  CHECKPOINT = os.environ.get("OMNIVOICE_MODEL", "k2-fsa/OmniVoice")
31
 
32
- logger.info(f"Loading model from {CHECKPOINT} on {DEVICE} ...")
33
- model = OmniVoice.from_pretrained(
34
- CHECKPOINT,
35
- device_map=DEVICE,
36
- dtype=torch.float16,
37
- load_asr=True,
38
- )
39
- logger.info("Model loaded on %s.", DEVICE)
40
- sampling_rate = model.sampling_rate
 
 
 
 
 
 
 
41
 
42
  # ---------------------------------------------------------------------------
43
  # Generation logic (outside build_demo so we can wrap with spaces.GPU)
@@ -107,17 +120,24 @@ def _gen_core(
107
  # ZeroGPU wrapper
108
  # ---------------------------------------------------------------------------
109
  generate_fn = None
110
- try:
111
- import spaces
112
-
113
  @spaces.GPU()
114
- def _gen_gpu(*args, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
115
  return _gen_core(*args, **kwargs)
116
 
117
- generate_fn = _gen_gpu
118
  logger.info("Using spaces.GPU() wrapper.")
119
- except ImportError:
120
- logger.info("spaces module not found, running without GPU wrapper.")
121
 
122
  # ---------------------------------------------------------------------------
123
  # Build and launch demo — reuses the full UI from omnivoice.cli.demo
 
9
  import tempfile
10
  from typing import Any, Dict
11
 
12
+ try:
13
+ import spaces
14
+ _USING_ZERO_GPU = True
15
+ except ImportError:
16
+ _USING_ZERO_GPU = False
17
+
18
  import torch
19
  import torchaudio
20
 
 
22
  from omnivoice.cli.demo import build_demo
23
 
24
  logger = logging.getLogger(__name__)
25
+ logging.basicConfig(level=logging.INFO)
26
 
27
  # ---------------------------------------------------------------------------
28
  # Hardware detection
 
35
  # ---------------------------------------------------------------------------
36
  CHECKPOINT = os.environ.get("OMNIVOICE_MODEL", "k2-fsa/OmniVoice")
37
 
38
+ model = None
39
+ if not _USING_ZERO_GPU:
40
+ # Non-ZeroGPU: load model at startup on the best available device
41
+ logger.info(f"Loading model from {CHECKPOINT} on {DEVICE} ...")
42
+ model = OmniVoice.from_pretrained(
43
+ CHECKPOINT,
44
+ device_map=DEVICE,
45
+ dtype=torch.float16,
46
+ load_asr=True,
47
+ )
48
+ logger.info("Model loaded on %s.", DEVICE)
49
+ else:
50
+ logger.info("ZeroGPU mode: model will be loaded inside @spaces.GPU() function.")
51
+
52
+ sampling_rate = 16000 # fallback; will be overwritten after model loads
53
+
54
 
55
  # ---------------------------------------------------------------------------
56
  # Generation logic (outside build_demo so we can wrap with spaces.GPU)
 
120
  # ZeroGPU wrapper
121
  # ---------------------------------------------------------------------------
122
  generate_fn = None
123
+ if _USING_ZERO_GPU:
 
 
124
  @spaces.GPU()
125
+ def generate_fn(*args, **kwargs):
126
+ # Lazy-load model on first call (inside GPU context)
127
+ global model, sampling_rate
128
+ if model is None:
129
+ logger.info(f"Loading model from {CHECKPOINT} on cuda (ZeroGPU) ...")
130
+ model = OmniVoice.from_pretrained(
131
+ CHECKPOINT,
132
+ device_map="cuda",
133
+ dtype=torch.float16,
134
+ load_asr=True,
135
+ )
136
+ sampling_rate = model.sampling_rate
137
+ logger.info("Model loaded on cuda (ZeroGPU).")
138
  return _gen_core(*args, **kwargs)
139
 
 
140
  logger.info("Using spaces.GPU() wrapper.")
 
 
141
 
142
  # ---------------------------------------------------------------------------
143
  # Build and launch demo — reuses the full UI from omnivoice.cli.demo