Spaces:
Sleeping
Sleeping
import io
import os
import re
import tempfile
import wave
from datetime import datetime
from enum import Enum
from typing import Literal

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from faster_whisper import WhisperModel
from piper import PiperVoice, SynthesisConfig
from pydantic import BaseModel
from starlette.background import BackgroundTask
# -------------------- CONFIG --------------------
# Directory containing the Piper .onnx voice model files.
VOICE_DIR = "actors"
# Local path of the faster-whisper model used for speech-to-text.
WHISPER_MODEL_PATH = "./models/faster-whisper-tiny"

# Shared Piper synthesis settings applied to every TTS request.
# NOTE(review): parameter semantics per the Piper SynthesisConfig docs —
# length_scale > 1.0 slows speech slightly; the noise scales control audio/
# duration variability. Confirm against the installed piper version.
syn_config = SynthesisConfig(
    volume=1.0,
    length_scale=1.15,
    noise_scale=0.55,
    noise_w_scale=0.7,
    normalize_audio=True,
)
# -------------------- ENUMS --------------------
class VoiceActor(str, Enum):
    """Available Piper voices; each value is a model filename under VOICE_DIR."""
    alba = "en_GB-alba-medium.onnx"
    hfc_female = "en_US-hfc_female-medium.onnx"
    danny = "en_US-danny-low.onnx"
    lessac = "en_US-lessac-high.onnx"
    libritts = "en_US-libritts-high.onnx"
    cori = "en_GB-cori-high.onnx"
class Input(BaseModel):
    """Request body for the JSON TTS endpoint."""

    # Text to synthesize.
    text: str
    # Piper model filename selecting the voice. Optional; defaults to alba.
    # Fix: the Literal previously omitted the libritts and cori voices even
    # though they exist in VoiceActor — all six known voices are now accepted.
    actor: Literal[
        "en_GB-alba-medium.onnx",
        "en_US-hfc_female-medium.onnx",
        "en_US-danny-low.onnx",
        "en_US-lessac-high.onnx",
        "en_US-libritts-high.onnx",
        "en_GB-cori-high.onnx",
    ] | None = VoiceActor.alba.value
# -------------------- APP --------------------
app = FastAPI(title="Fast TTS + STT API")

# -------------------- MODEL CACHE --------------------
print("🔹 Loading Whisper model...")
# Load the STT model once at import time so requests don't pay startup cost.
stt_model = WhisperModel(
    WHISPER_MODEL_PATH,
    device="cpu",
    compute_type="int8",  # int8 quantization keeps CPU inference fast
    cpu_threads=os.cpu_count(),
    num_workers=1,
)
print("🔹 Whisper loaded")

# Lazily populated cache of loaded Piper voices, keyed by model filename.
voice_cache: dict[str, PiperVoice] = {}
def get_voice(actor: str) -> PiperVoice:
    """Return the Piper voice for *actor*, loading and caching it on first use."""
    voice = voice_cache.get(actor)
    if voice is None:
        voice = PiperVoice.load(f"{VOICE_DIR}/{actor}")
        voice_cache[actor] = voice
    return voice
def chunk_text(text: str, max_tokens: int = 150) -> list[str]:
    """Split *text* into word chunks of at most roughly *max_tokens* words.

    The text is first split on sentence boundaries (., !, ? followed by
    whitespace); whole sentences are packed greedily into chunks so a chunk
    never ends mid-sentence. Limitation: a single sentence longer than
    *max_tokens* is kept intact and yields one oversized chunk.

    Returns an empty list for empty/whitespace-only input.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    chunks: list[str] = []
    current: list[str] = []
    for sentence in sentences:
        words = sentence.split()
        if len(current) + len(words) <= max_tokens:
            current.extend(words)
        else:
            # Fix: only flush a non-empty chunk. Previously, when the very
            # first sentence alone exceeded max_tokens, an empty string was
            # appended to the chunk list.
            if current:
                chunks.append(" ".join(current))
            current = words
    if current:
        chunks.append(" ".join(current))
    return chunks
def synthesize_chunked_tts(text: str, voice, syn_config):
    """Synthesize *text* chunk by chunk and return one in-memory WAV stream.

    Long inputs are split into sentence-aligned chunks (see chunk_text); each
    chunk is rendered into its own scratch WAV, and the raw PCM frames are
    concatenated into a single mono 16-bit WAV at the voice's sample rate.
    Returns a BytesIO positioned at the start of the WAV data.
    """
    sample_rate = voice.config.sample_rate
    combined = io.BytesIO()
    with wave.open(combined, "wb") as writer:
        writer.setnchannels(1)   # mono
        writer.setsampwidth(2)   # 16-bit PCM
        writer.setframerate(sample_rate)
        for piece in chunk_text(text, max_tokens=150):
            scratch = io.BytesIO()
            with wave.open(scratch, "wb") as chunk_writer:
                chunk_writer.setnchannels(1)
                chunk_writer.setsampwidth(2)
                chunk_writer.setframerate(sample_rate)
                voice.synthesize_wav(piece, chunk_writer, syn_config=syn_config)
            scratch.seek(0)
            # Re-open the scratch buffer for reading and append its frames.
            with wave.open(scratch, "rb") as chunk_reader:
                writer.writeframes(chunk_reader.readframes(chunk_reader.getnframes()))
    combined.seek(0)
    return combined
# -------------------- ROUTES --------------------
# NOTE(review): none of the handler functions in this module carry
# @app.<method> decorators — route registration appears to have been lost
# when this file was extracted; confirm against the original source.
def root():
    """Health-check payload."""
    return dict(status="ok")
# -------- TTS (JSON, returns file) --------
def tts_demo(input: Input):
    """Synthesize ``input.text`` with the requested voice; return a WAV file.

    Fixes over the previous version:
    - ``input.actor`` is Optional; ``None`` now falls back to the default
      alba voice instead of crashing inside ``get_voice``.
    - The temporary WAV is deleted after the response is sent (via a
      Starlette BackgroundTask); previously these files accumulated on disk.
    """
    actor = input.actor or VoiceActor.alba.value
    voice = get_voice(actor)
    # Create a named temp file that outlives this function so FileResponse
    # can stream it; cleanup happens in the background task below.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_path = temp_file.name
    temp_file.close()
    with wave.open(temp_path, "wb") as wav:
        wav.setnchannels(1)   # mono, 16-bit PCM at the voice's native rate
        wav.setsampwidth(2)
        wav.setframerate(voice.config.sample_rate)
        voice.synthesize_wav(input.text, wav, syn_config=syn_config)
    return FileResponse(
        temp_path,
        filename=f"tts-{int(datetime.now().timestamp())}.wav",
        media_type="audio/wav",
        background=BackgroundTask(os.unlink, temp_path),
    )
# -------- TTS (FORM, STREAMING – FASTEST) --------
# NOTE(review): a second `def tts` appears later in this module and rebinds
# the name at import time, making this definition dead code unless a route
# decorator (not visible here) captured it first — confirm which
# implementation is the intended endpoint.
def tts(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    # Synthesize the entire text in one pass into an in-memory WAV and
    # stream it back without touching disk.
    voice = get_voice(actor.value)
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(1)   # mono
        wav.setsampwidth(2)   # 16-bit PCM
        wav.setframerate(voice.config.sample_rate)
        voice.synthesize_wav(text, wav, syn_config=syn_config)
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="audio/wav")
# -------- STT ONLY --------
async def speech_to_text(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with faster-whisper.

    The upload is spooled to a temp file (faster-whisper reads from a path).
    Returns the transcript plus detected language and audio duration.

    Fixes over the previous version:
    - the temp file is removed in a ``finally`` block, so it no longer leaks
      when ``transcribe`` raises;
    - the lazy segment generator is materialized before cleanup, and each
      segment is stripped so the join doesn't produce doubled spaces.
    """
    audio_bytes = await file.read()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        temp_path = f.name
    try:
        segments, info = stt_model.transcribe(
            temp_path,
            beam_size=1,  # greedy decoding: fastest option
            language="en",
            vad_filter=True,
            vad_parameters={"min_silence_duration_ms": 500},
        )
        # `segments` is a generator — consume it while the file still exists.
        text = " ".join(seg.text.strip() for seg in segments)
    finally:
        os.unlink(temp_path)
    return {
        "text": text,
        "language": info.language,
        "duration": info.duration,
    }
def tts(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Chunked TTS endpoint: split long text, synthesize, stream one WAV.

    NOTE(review): this rebinds the module-level name ``tts`` defined earlier
    in the file; only this definition survives import — confirm whether the
    earlier one should be removed or registered under a different route.
    """
    selected_voice = get_voice(actor.value)
    wav_stream = synthesize_chunked_tts(
        text=text,
        voice=selected_voice,
        syn_config=syn_config,
    )
    return StreamingResponse(wav_stream, media_type="audio/wav")
# -------- STT → TTS --------
async def convert(
    file: UploadFile = File(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Transcribe uploaded audio, then re-synthesize the transcript as speech.

    *actor* selects the output voice — a new optional form field whose
    default is the previously hard-coded alba voice, so existing callers
    are unaffected.

    Fixes over the previous version:
    - the temp file is removed in a ``finally`` block instead of leaking
      when ``transcribe`` raises;
    - segments are stripped before joining to avoid doubled spaces.
    """
    audio_bytes = await file.read()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        temp_path = f.name
    try:
        segments, _ = stt_model.transcribe(
            temp_path,
            beam_size=1,
            language="en",
            vad_filter=True,
            vad_parameters={"min_silence_duration_ms": 500},
        )
        # Materialize the lazy generator while the temp file still exists.
        text = " ".join(seg.text.strip() for seg in segments)
    finally:
        os.unlink(temp_path)

    voice = get_voice(actor.value)
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(1)   # mono, 16-bit PCM
        wav.setsampwidth(2)
        wav.setframerate(voice.config.sample_rate)
        voice.synthesize_wav(text, wav, syn_config=syn_config)
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="audio/wav")