shreyask committed on
Commit
5c97b55
·
verified ·
1 Parent(s): 6525f03

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. src/App.tsx +21 -15
  2. src/WaveformPlayer.tsx +173 -0
  3. src/index.css +52 -8
  4. src/worker.ts +22 -11
src/App.tsx CHANGED
@@ -1,26 +1,32 @@
1
  import { useState, useRef, useCallback, useEffect } from "react";
 
2
 
3
  const MODELS: Record<string, string> = {
4
- "Nano Int8 (15M · Fastest)": "KittenML/kitten-tts-nano-0.8-int8",
5
- "Nano FP32 (15M)": "KittenML/kitten-tts-nano-0.8-fp32",
6
- "Micro (40M · Balanced)": "KittenML/kitten-tts-micro-0.8",
7
- "Mini (80M · Best Quality)": "KittenML/kitten-tts-mini-0.8",
8
  };
9
 
10
- const DEFAULT_MODEL = "Nano FP32 (15M)";
11
 
12
  const EXAMPLES = [
13
  {
14
  text: "Space is a three-dimensional continuum containing positions and directions.",
 
15
  voice: "Jasper",
 
16
  },
17
  {
18
  text: "She picked up her coffee and walked toward the window.",
 
19
  voice: "Luna",
 
20
  },
21
  {
22
- text: "The sun set slowly over the calm, quiet lake.",
 
23
  voice: "Bella",
 
24
  },
25
  ];
26
 
@@ -140,6 +146,10 @@ export default function App() {
140
  const handleExample = (ex: (typeof EXAMPLES)[0]) => {
141
  setText(ex.text);
142
  setVoice(ex.voice);
 
 
 
 
143
  };
144
 
145
  const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
@@ -243,14 +253,7 @@ export default function App() {
243
  <div className="output-section">
244
  <label>Output</label>
245
  {audioUrl ? (
246
- <div className="audio-result">
247
- <audio controls src={audioUrl} className="audio-player" />
248
- {duration !== null && (
249
- <span className="duration">
250
- Generated in {(duration / 1000).toFixed(1)}s
251
- </span>
252
- )}
253
- </div>
254
  ) : (
255
  <div className="audio-placeholder">
256
  {status === "loading" || status === "generating"
@@ -268,10 +271,13 @@ export default function App() {
268
  key={i}
269
  className="example-btn"
270
  onClick={() => handleExample(ex)}
271
- disabled={status !== "ready"}
272
  >
273
  <span className="example-voice">{ex.voice}</span>
274
  <span className="example-text">{ex.text}</span>
 
 
 
275
  </button>
276
  ))}
277
  </div>
 
1
  import { useState, useRef, useCallback, useEffect } from "react";
2
+ import WaveformPlayer from "./WaveformPlayer";
3
 
4
  const MODELS: Record<string, string> = {
5
+ "Nano (15M - Fastest)": "onnx-community/KittenTTS-Nano-v0.8-ONNX",
6
+ "Micro (40M - Balanced)": "onnx-community/KittenTTS-Micro-v0.8-ONNX",
7
+ "Mini (80M - Best Quality)": "onnx-community/KittenTTS-Mini-v0.8-ONNX",
 
8
  };
9
 
10
+ const DEFAULT_MODEL = "Micro (40M - Balanced)";
11
 
12
  const EXAMPLES = [
13
  {
14
  text: "Space is a three-dimensional continuum containing positions and directions.",
15
+ model: "Micro (40M - Balanced)",
16
  voice: "Jasper",
17
+ speed: 1.0,
18
  },
19
  {
20
  text: "She picked up her coffee and walked toward the window.",
21
+ model: "Mini (80M - Best Quality)",
22
  voice: "Luna",
23
+ speed: 1.0,
24
  },
25
  {
26
+ text: "The sun set slowly over the calm, quiet lake",
27
+ model: "Nano (15M - Fastest)",
28
  voice: "Bella",
29
+ speed: 1.1,
30
  },
31
  ];
32
 
 
146
  const handleExample = (ex: (typeof EXAMPLES)[0]) => {
147
  setText(ex.text);
148
  setVoice(ex.voice);
149
+ setSpeed(ex.speed);
150
+ if (ex.model !== model) {
151
+ handleModelChange(ex.model);
152
+ }
153
  };
154
 
155
  const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
 
253
  <div className="output-section">
254
  <label>Output</label>
255
  {audioUrl ? (
256
+ <WaveformPlayer audioUrl={audioUrl} duration={duration} />
 
 
 
 
 
 
 
257
  ) : (
258
  <div className="audio-placeholder">
259
  {status === "loading" || status === "generating"
 
271
  key={i}
272
  className="example-btn"
273
  onClick={() => handleExample(ex)}
274
+ disabled={status === "loading" || status === "generating"}
275
  >
276
  <span className="example-voice">{ex.voice}</span>
277
  <span className="example-text">{ex.text}</span>
278
+ <span className="example-meta">
279
+ {ex.model.split(" (")[0]}{ex.speed !== 1.0 ? ` · ${ex.speed}x` : ""}
280
+ </span>
281
  </button>
282
  ))}
283
  </div>
src/WaveformPlayer.tsx ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useRef, useEffect, useState, useCallback } from "react";
2
+
3
+ interface WaveformPlayerProps {
4
+ audioUrl: string;
5
+ duration?: number | null;
6
+ }
7
+
8
+ export default function WaveformPlayer({ audioUrl, duration }: WaveformPlayerProps) {
9
+ const canvasRef = useRef<HTMLCanvasElement>(null);
10
+ const audioRef = useRef<HTMLAudioElement>(null);
11
+ const animRef = useRef<number>(0);
12
+ const waveformRef = useRef<number[]>([]);
13
+
14
+ const [playing, setPlaying] = useState(false);
15
+ const [currentTime, setCurrTime] = useState(0);
16
+ const [totalDuration, setTotalDuration] = useState(0);
17
+ const [hovering, setHovering] = useState(false);
18
+ const [hoverX, setHoverX] = useState(0);
19
+
20
+ // Decode audio and compute waveform peaks
21
+ useEffect(() => {
22
+ if (!audioUrl) return;
23
+ const ctx = new AudioContext();
24
+ fetch(audioUrl)
25
+ .then((r) => r.arrayBuffer())
26
+ .then((buf) => ctx.decodeAudioData(buf))
27
+ .then((decoded) => {
28
+ const raw = decoded.getChannelData(0);
29
+ const bars = 100;
30
+ const blockSize = Math.floor(raw.length / bars);
31
+ const peaks: number[] = [];
32
+ for (let i = 0; i < bars; i++) {
33
+ let sum = 0;
34
+ for (let j = 0; j < blockSize; j++) {
35
+ sum += Math.abs(raw[i * blockSize + j]);
36
+ }
37
+ peaks.push(sum / blockSize);
38
+ }
39
+ // Normalize
40
+ const max = Math.max(...peaks, 0.01);
41
+ waveformRef.current = peaks.map((p) => p / max);
42
+ drawWaveform();
43
+ ctx.close();
44
+ })
45
+ .catch(() => {});
46
+ }, [audioUrl]);
47
+
48
+ const drawWaveform = useCallback(() => {
49
+ const canvas = canvasRef.current;
50
+ if (!canvas) return;
51
+ const ctx = canvas.getContext("2d");
52
+ if (!ctx) return;
53
+
54
+ const dpr = window.devicePixelRatio || 1;
55
+ const rect = canvas.getBoundingClientRect();
56
+ canvas.width = rect.width * dpr;
57
+ canvas.height = rect.height * dpr;
58
+ ctx.scale(dpr, dpr);
59
+
60
+ const w = rect.width;
61
+ const h = rect.height;
62
+ const peaks = waveformRef.current;
63
+ const bars = peaks.length || 1;
64
+ const audio = audioRef.current;
65
+ const progress = audio && audio.duration ? audio.currentTime / audio.duration : 0;
66
+
67
+ ctx.clearRect(0, 0, w, h);
68
+
69
+ const barWidth = (w / bars) * 0.7;
70
+ const gap = (w / bars) * 0.3;
71
+ const mid = h / 2;
72
+
73
+ for (let i = 0; i < bars; i++) {
74
+ const x = (i / bars) * w;
75
+ const barH = Math.max(2, (peaks[i] || 0) * mid * 0.9);
76
+ const iPlayed = i / bars < progress;
77
+
78
+ ctx.fillStyle = iPlayed ? "#c084fc" : hovering && x < hoverX ? "rgba(192,132,252,0.4)" : "#444";
79
+ ctx.beginPath();
80
+ ctx.roundRect(x + gap / 2, mid - barH, barWidth, barH * 2, 1.5);
81
+ ctx.fill();
82
+ }
83
+ }, [hovering, hoverX]);
84
+
85
+ // Animation loop
86
+ useEffect(() => {
87
+ const tick = () => {
88
+ const audio = audioRef.current;
89
+ if (audio) setCurrTime(audio.currentTime);
90
+ drawWaveform();
91
+ animRef.current = requestAnimationFrame(tick);
92
+ };
93
+ animRef.current = requestAnimationFrame(tick);
94
+ return () => cancelAnimationFrame(animRef.current);
95
+ }, [drawWaveform]);
96
+
97
+ const togglePlay = () => {
98
+ const audio = audioRef.current;
99
+ if (!audio) return;
100
+ if (audio.paused) {
101
+ audio.play();
102
+ setPlaying(true);
103
+ } else {
104
+ audio.pause();
105
+ setPlaying(false);
106
+ }
107
+ };
108
+
109
+ const seek = (e: React.MouseEvent<HTMLCanvasElement>) => {
110
+ const audio = audioRef.current;
111
+ const canvas = canvasRef.current;
112
+ if (!audio || !canvas || !audio.duration) return;
113
+ const rect = canvas.getBoundingClientRect();
114
+ const ratio = (e.clientX - rect.left) / rect.width;
115
+ audio.currentTime = ratio * audio.duration;
116
+ };
117
+
118
+ const handleMouseMove = (e: React.MouseEvent<HTMLCanvasElement>) => {
119
+ const canvas = canvasRef.current;
120
+ if (!canvas) return;
121
+ const rect = canvas.getBoundingClientRect();
122
+ setHoverX(e.clientX - rect.left);
123
+ };
124
+
125
+ const fmt = (s: number) => {
126
+ const m = Math.floor(s / 60);
127
+ const sec = Math.floor(s % 60);
128
+ return `${m}:${sec.toString().padStart(2, "0")}`;
129
+ };
130
+
131
+ return (
132
+ <div className="waveform-player">
133
+ <audio
134
+ ref={audioRef}
135
+ src={audioUrl}
136
+ onLoadedMetadata={() => setTotalDuration(audioRef.current?.duration || 0)}
137
+ onEnded={() => setPlaying(false)}
138
+ />
139
+
140
+ <button className="waveform-play" onClick={togglePlay} aria-label={playing ? "Pause" : "Play"}>
141
+ {playing ? (
142
+ <svg width="16" height="16" viewBox="0 0 16 16" fill="currentColor">
143
+ <rect x="3" y="2" width="4" height="12" rx="1" />
144
+ <rect x="9" y="2" width="4" height="12" rx="1" />
145
+ </svg>
146
+ ) : (
147
+ <svg width="16" height="16" viewBox="0 0 16 16" fill="currentColor">
148
+ <path d="M4 2.5v11l9-5.5z" />
149
+ </svg>
150
+ )}
151
+ </button>
152
+
153
+ <canvas
154
+ ref={canvasRef}
155
+ className="waveform-canvas"
156
+ onClick={seek}
157
+ onMouseEnter={() => setHovering(true)}
158
+ onMouseLeave={() => setHovering(false)}
159
+ onMouseMove={handleMouseMove}
160
+ />
161
+
162
+ <span className="waveform-time">
163
+ {fmt(currentTime)} / {fmt(totalDuration)}
164
+ </span>
165
+
166
+ {duration !== null && duration !== undefined && (
167
+ <span className="waveform-gen-time">
168
+ {(duration / 1000).toFixed(1)}s
169
+ </span>
170
+ )}
171
+ </div>
172
+ );
173
+ }
src/index.css CHANGED
@@ -216,21 +216,57 @@ input[type="range"] {
216
  padding: 1rem;
217
  }
218
 
219
- .audio-result {
 
 
220
  display: flex;
221
- flex-direction: column;
222
- gap: 0.5rem;
223
  }
224
 
225
- .audio-player {
226
- width: 100%;
227
- border-radius: var(--radius);
 
 
 
 
 
 
 
 
 
 
228
  }
229
 
230
- .duration {
231
- font-size: 0.75rem;
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  color: var(--text-muted);
 
 
 
 
 
 
233
  font-family: var(--mono);
 
 
 
 
 
234
  }
235
 
236
  .audio-placeholder {
@@ -287,6 +323,14 @@ input[type="range"] {
287
 
288
  .example-text {
289
  color: var(--text-muted);
 
 
 
 
 
 
 
 
290
  }
291
 
292
  /* Error */
 
216
  padding: 1rem;
217
  }
218
 
219
+ /* Waveform player */
220
+
221
+ .waveform-player {
222
  display: flex;
223
+ align-items: center;
224
+ gap: 0.75rem;
225
  }
226
 
227
+ .waveform-play {
228
+ flex-shrink: 0;
229
+ width: 36px;
230
+ height: 36px;
231
+ border-radius: 50%;
232
+ border: none;
233
+ background: var(--accent);
234
+ color: #111;
235
+ cursor: pointer;
236
+ display: flex;
237
+ align-items: center;
238
+ justify-content: center;
239
+ transition: opacity 0.15s;
240
  }
241
 
242
+ .waveform-play:hover {
243
+ opacity: 0.85;
244
+ }
245
+
246
+ .waveform-canvas {
247
+ flex: 1;
248
+ height: 48px;
249
+ cursor: pointer;
250
+ border-radius: 4px;
251
+ }
252
+
253
+ .waveform-time {
254
+ flex-shrink: 0;
255
+ font-family: var(--mono);
256
+ font-size: 0.7rem;
257
  color: var(--text-muted);
258
+ min-width: 5.5em;
259
+ text-align: right;
260
+ }
261
+
262
+ .waveform-gen-time {
263
+ flex-shrink: 0;
264
  font-family: var(--mono);
265
+ font-size: 0.65rem;
266
+ color: #555;
267
+ padding: 0.15rem 0.4rem;
268
+ background: var(--surface-2);
269
+ border-radius: 4px;
270
  }
271
 
272
  .audio-placeholder {
 
323
 
324
  .example-text {
325
  color: var(--text-muted);
326
+ flex: 1;
327
+ }
328
+
329
+ .example-meta {
330
+ flex-shrink: 0;
331
+ font-family: var(--mono);
332
+ font-size: 0.7rem;
333
+ color: #555;
334
  }
335
 
336
  /* Error */
src/worker.ts CHANGED
@@ -16,8 +16,8 @@ let ort: any;
16
  const HF_BASE = "https://huggingface.co";
17
  const SAMPLE_RATE = 24000;
18
 
19
- // Int8 quantized models produce NaN on WebGPU; all fp32 models should be fine
20
- const WEBGPU_BLOCKED_PATTERNS = ["int8"];
21
 
22
  interface ModelConfig {
23
  name: string;
@@ -63,19 +63,22 @@ async function loadModel(repoId: string) {
63
  ort = ortModule;
64
  phonemize = phonemizerModule.phonemize;
65
 
66
- // Load config
67
  self.postMessage({ type: "status", message: "Loading config..." });
68
- const configUrl = resolveUrl(repoId, "config.json");
69
- const configResp = await fetch(configUrl);
 
 
 
70
  config = (await configResp.json()) as ModelConfig;
71
 
72
- // Int8 quantized models produce NaN on WebGPU only block those
73
  const modelName = config.model || repoId.split("/").pop() || "";
74
- const isBlocked = WEBGPU_BLOCKED_PATTERNS.some((p) => modelName.includes(p));
75
- currentDevice = hasWebGPU && !isBlocked ? "webgpu" : "wasm";
76
 
77
- if (hasWebGPU && isBlocked) {
78
- console.log(`[KittenTTS] Using WASM for "${modelName}" (int8 models produce NaN on WebGPU)`);
79
  }
80
 
81
  self.postMessage({ type: "device", device: currentDevice });
@@ -83,7 +86,9 @@ async function loadModel(repoId: string) {
83
  // Load voices (.npz) and ONNX model in parallel
84
  self.postMessage({ type: "status", message: "Downloading model & voices..." });
85
 
86
- const modelUrl = resolveUrl(repoId, config.model_file);
 
 
87
 
88
  const modelPromise = (async () => {
89
  const resp = await fetch(modelUrl);
@@ -244,6 +249,12 @@ async function generateChunk(
244
  const outputKey = session.outputNames[0];
245
  const audioData = results[outputKey].data as Float32Array;
246
 
 
 
 
 
 
 
247
  // Trim trailing silence (matching Python: audio[..., :-5000])
248
  return audioData.slice(0, Math.max(0, audioData.length - 5000));
249
  }
 
16
  const HF_BASE = "https://huggingface.co";
17
  const SAMPLE_RATE = 24000;
18
 
19
+ // Only nano (fp32) confirmed working on WebGPU; micro/mini are int8 quantized
20
+ const WEBGPU_SAFE_MODELS = ["Nano", "nano", "fp32"];
21
 
22
  interface ModelConfig {
23
  name: string;
 
63
  ort = ortModule;
64
  phonemize = phonemizerModule.phonemize;
65
 
66
+ // Load config (onnx-community repos use kitten_config.json for the TTS config)
67
  self.postMessage({ type: "status", message: "Loading config..." });
68
+ let configResp = await fetch(resolveUrl(repoId, "kitten_config.json"));
69
+ if (!configResp.ok) {
70
+ // Fallback to config.json for original KittenML repos
71
+ configResp = await fetch(resolveUrl(repoId, "config.json"));
72
+ }
73
  config = (await configResp.json()) as ModelConfig;
74
 
75
+ // Only use WebGPU for models confirmed to work (nano-fp32)
76
  const modelName = config.model || repoId.split("/").pop() || "";
77
+ const isSafe = WEBGPU_SAFE_MODELS.some((m) => modelName.includes(m));
78
+ currentDevice = hasWebGPU && isSafe ? "webgpu" : "wasm";
79
 
80
+ if (hasWebGPU && !isSafe) {
81
+ console.log(`[KittenTTS] Using WASM for "${modelName}" (WebGPU only confirmed for nano-fp32)`);
82
  }
83
 
84
  self.postMessage({ type: "device", device: currentDevice });
 
86
  // Load voices (.npz) and ONNX model in parallel
87
  self.postMessage({ type: "status", message: "Downloading model & voices..." });
88
 
89
+ // onnx-community repos have model at onnx/model.onnx, original repos use config.model_file
90
+ const modelFile = config.model_file || "onnx/model.onnx";
91
+ const modelUrl = resolveUrl(repoId, modelFile);
92
 
93
  const modelPromise = (async () => {
94
  const resp = await fetch(modelUrl);
 
249
  const outputKey = session.outputNames[0];
250
  const audioData = results[outputKey].data as Float32Array;
251
 
252
+ // Check for NaN — if detected, the model doesn't work on this backend
253
+ const hasNaN = audioData.length > 0 && isNaN(audioData[0]);
254
+ if (hasNaN) {
255
+ console.warn(`[KittenTTS] Model produced NaN audio — this model may not be compatible with ${currentDevice.toUpperCase()}`);
256
+ }
257
+
258
  // Trim trailing silence (matching Python: audio[..., :-5000])
259
  return audioData.slice(0, Math.max(0, audioData.length - 5000));
260
  }