zhu-han committed on
Commit
69feda9
·
verified ·
1 Parent(s): bd9d198

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -59
app.py CHANGED
@@ -4,56 +4,32 @@ HuggingFace Space entry point for OmniVoice demo.
4
 
5
  """
6
 
7
- import logging
8
  import os
9
- import tempfile
10
  from typing import Any, Dict
11
 
12
- try:
13
- import spaces
14
- _USING_ZERO_GPU = True
15
- except ImportError:
16
- _USING_ZERO_GPU = False
17
-
18
  import torch
19
- import torchaudio
20
-
21
  from omnivoice import OmniVoice, OmniVoiceGenerationConfig
22
  from omnivoice.cli.demo import build_demo
23
 
24
- logger = logging.getLogger(__name__)
25
- logging.basicConfig(level=logging.INFO)
26
-
27
- # ---------------------------------------------------------------------------
28
- # Hardware detection
29
- # ---------------------------------------------------------------------------
30
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
- logger.info(f"Using device: {DEVICE}")
32
-
33
  # ---------------------------------------------------------------------------
34
  # Model loading
35
  # ---------------------------------------------------------------------------
36
  CHECKPOINT = os.environ.get("OMNIVOICE_MODEL", "k2-fsa/OmniVoice")
37
 
38
- model = None
39
- if not _USING_ZERO_GPU:
40
- # Non-ZeroGPU: load model at startup on the best available device
41
- logger.info(f"Loading model from {CHECKPOINT} on {DEVICE} ...")
42
- model = OmniVoice.from_pretrained(
43
- CHECKPOINT,
44
- device_map=DEVICE,
45
- dtype=torch.float16,
46
- load_asr=True,
47
- )
48
- logger.info("Model loaded on %s.", DEVICE)
49
- else:
50
- logger.info("ZeroGPU mode: model will be loaded inside @spaces.GPU() function.")
51
-
52
- sampling_rate = 16000 # fallback; will be overwritten after model loads
53
-
54
 
55
  # ---------------------------------------------------------------------------
56
- # Generation logic (outside build_demo so we can wrap with spaces.GPU)
57
  # ---------------------------------------------------------------------------
58
 
59
 
@@ -107,40 +83,27 @@ def _gen_core(
107
  kw["instruct"] = instruct.strip()
108
 
109
  try:
110
- out_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
111
  audio = model.generate(**kw)
112
- torchaudio.save(out_path, audio[0], sampling_rate)
113
  except Exception as e:
114
  return None, f"Error: {type(e).__name__}: {e}"
115
 
116
- return out_path, "Done."
 
 
117
 
118
 
119
  # ---------------------------------------------------------------------------
120
  # ZeroGPU wrapper
121
  # ---------------------------------------------------------------------------
122
- generate_fn = None
123
- if _USING_ZERO_GPU:
124
- @spaces.GPU()
125
- def generate_fn(*args, **kwargs):
126
- # Lazy-load model on first call (inside GPU context)
127
- global model, sampling_rate
128
- if model is None:
129
- logger.info(f"Loading model from {CHECKPOINT} on cuda (ZeroGPU) ...")
130
- model = OmniVoice.from_pretrained(
131
- CHECKPOINT,
132
- device_map="cuda",
133
- dtype=torch.float16,
134
- load_asr=True,
135
- )
136
- sampling_rate = model.sampling_rate
137
- logger.info("Model loaded on cuda (ZeroGPU).")
138
- return _gen_core(*args, **kwargs)
139
-
140
- logger.info("Using spaces.GPU() wrapper.")
141
 
142
  # ---------------------------------------------------------------------------
143
- # Build and launch demo — reuses the full UI from omnivoice.cli.demo
144
  # ---------------------------------------------------------------------------
145
  demo = build_demo(model, CHECKPOINT, generate_fn=generate_fn)
146
 
 
4
 
5
  """
6
 
 
7
  import os
 
8
  from typing import Any, Dict
9
 
10
+ import numpy as np
11
+ import spaces
 
 
 
 
12
  import torch
 
 
13
  from omnivoice import OmniVoice, OmniVoiceGenerationConfig
14
  from omnivoice.cli.demo import build_demo
15
 
 
 
 
 
 
 
 
 
 
16
  # ---------------------------------------------------------------------------
17
  # Model loading
18
  # ---------------------------------------------------------------------------
19
  CHECKPOINT = os.environ.get("OMNIVOICE_MODEL", "k2-fsa/OmniVoice")
20
 
21
+ print(f"Loading model from {CHECKPOINT} to cuda ...")
22
+ model = OmniVoice.from_pretrained(
23
+ CHECKPOINT,
24
+ device_map="cuda",
25
+ dtype=torch.float16,
26
+ load_asr=True,
27
+ )
28
+ sampling_rate = model.sampling_rate
29
+ print("Model loaded successfully!")
 
 
 
 
 
 
 
30
 
31
  # ---------------------------------------------------------------------------
32
+ # Generation logic
33
  # ---------------------------------------------------------------------------
34
 
35
 
 
83
  kw["instruct"] = instruct.strip()
84
 
85
  try:
 
86
  audio = model.generate(**kw)
 
87
  except Exception as e:
88
  return None, f"Error: {type(e).__name__}: {e}"
89
 
90
+ waveform = audio[0].squeeze(0).numpy()
91
+ waveform = (waveform * 32767).astype(np.int16)
92
+ return (sampling_rate, waveform), "Done."
93
 
94
 
95
  # ---------------------------------------------------------------------------
96
  # ZeroGPU wrapper
97
  # ---------------------------------------------------------------------------
98
+
99
+
100
+ @spaces.GPU(duration=60)
101
+ def generate_fn(*args, **kwargs):
102
+ return _gen_core(*args, **kwargs)
103
+
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # ---------------------------------------------------------------------------
106
+ # Build and launch demo
107
  # ---------------------------------------------------------------------------
108
  demo = build_demo(model, CHECKPOINT, generate_fn=generate_fn)
109