# speech/app.py — FastAPI service: TTS (Piper) + STT (faster-whisper).
# Origin: harsh-dev, commit 12bc4c0 ("Add application file").
from fastapi import FastAPI, Form, UploadFile, File
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
from enum import Enum
from typing import Literal
from datetime import datetime
import tempfile
import wave
import io
import os
import re
from piper import PiperVoice, SynthesisConfig
from faster_whisper import WhisperModel
# -------------------- PATHS --------------------
# Resolve resources relative to this file so the service works from any CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VOICE_DIR = os.path.join(BASE_DIR, "actors")  # Piper .onnx voice models live here
WHISPER_MODEL_PATH = os.path.join(BASE_DIR, "models", "faster-whisper-tiny")
# -------------------- CONFIG --------------------
# Shared Piper synthesis settings, reused by every TTS route.
syn_config = SynthesisConfig(
    volume=1.0,
    length_scale=1.15,  # >1.0 stretches duration, i.e. slightly slower speech
    noise_scale=0.55,
    noise_w_scale=0.7,
    normalize_audio=True,
)
# -------------------- ENUMS --------------------
class VoiceActor(str, Enum):
    """Selectable Piper voice model filenames (resolved under VOICE_DIR)."""

    alba = "en_GB-alba-medium.onnx"
    hfc_female = "en_US-hfc_female-medium.onnx"
    danny = "en_US-danny-low.onnx"
    lessac = "en_US-lessac-high.onnx"
    libritts = "en_US-libritts-high.onnx"
    cori = "en_GB-cori-high.onnx"
class Input(BaseModel):
    """JSON request body for TTS.

    Attributes:
        text: The text to synthesize.
        actor: Voice model filename; defaults to the Alba voice.
    """

    text: str
    # Keep this literal set in sync with VoiceActor — libritts and cori were
    # previously missing here even though the enum (and VOICE_DIR) offer them.
    actor: Literal[
        "en_GB-alba-medium.onnx",
        "en_US-hfc_female-medium.onnx",
        "en_US-danny-low.onnx",
        "en_US-lessac-high.onnx",
        "en_US-libritts-high.onnx",
        "en_GB-cori-high.onnx",
    ] | None = VoiceActor.alba.value
# -------------------- APP --------------------
app = FastAPI(title="Fast TTS + STT API")
# -------------------- MODEL CACHE --------------------
print("🔹 Loading Whisper model...")
# Load the STT model once at import time; int8 quantization keeps CPU
# inference fast and memory-light.
stt_model = WhisperModel(
    WHISPER_MODEL_PATH,
    device="cpu",
    compute_type="int8",
    cpu_threads=os.cpu_count(),
)
print("🔹 Whisper loaded")
# Lazily populated cache of loaded Piper voices, keyed by model filename.
voice_cache: dict[str, PiperVoice] = {}
def get_voice(actor: str) -> PiperVoice:
    """Return the Piper voice for *actor*, loading and caching it on first use."""
    try:
        return voice_cache[actor]
    except KeyError:
        loaded = PiperVoice.load(os.path.join(VOICE_DIR, actor))
        voice_cache[actor] = loaded
        return loaded
# -------------------- UTILS --------------------
def chunk_text(text: str, max_tokens: int = 150) -> list[str]:
    """Split *text* into word-count-bounded chunks on sentence boundaries.

    Sentences (split after ., ! or ?) are greedily packed into chunks of at
    most *max_tokens* words. A single sentence longer than *max_tokens*
    becomes its own oversized chunk rather than being cut mid-sentence.

    Fixes a bug in the original: when the first sentence alone exceeded
    *max_tokens*, an empty-string chunk was emitted before it.

    Returns an empty list for blank input.
    """
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks: list[str] = []
    current: list[str] = []
    for sentence in sentences:
        words = sentence.split()
        if not words:
            continue
        if current and len(current) + len(words) > max_tokens:
            # Current chunk is full — flush it before starting the next one.
            chunks.append(" ".join(current))
            current = words
        else:
            current.extend(words)
    if current:
        chunks.append(" ".join(current))
    return chunks
def synthesize_chunked_tts(text: str, voice, syn_config):
    """Synthesize *text* chunk by chunk and return one in-memory WAV stream.

    Each chunk from chunk_text() is rendered into its own scratch WAV buffer,
    then its raw PCM frames are appended to a single combined 16-bit mono
    stream at the voice's native sample rate. The returned BytesIO is rewound
    and ready to stream.
    """
    rate = voice.config.sample_rate
    combined = io.BytesIO()
    with wave.open(combined, "wb") as merged:
        merged.setnchannels(1)
        merged.setsampwidth(2)  # 16-bit PCM
        merged.setframerate(rate)
        for piece in chunk_text(text):
            scratch = io.BytesIO()
            with wave.open(scratch, "wb") as part_writer:
                part_writer.setnchannels(1)
                part_writer.setsampwidth(2)
                part_writer.setframerate(rate)
                voice.synthesize_wav(piece, part_writer, syn_config=syn_config)
            scratch.seek(0)
            # Re-open the scratch buffer for reading and splice its frames
            # (header-free) into the merged stream.
            with wave.open(scratch, "rb") as part_reader:
                merged.writeframes(part_reader.readframes(part_reader.getnframes()))
    combined.seek(0)
    return combined
# -------------------- ROUTES --------------------
@app.get("/")
def root():
    """Liveness probe — confirms the service is up."""
    return dict(status="ok")
@app.post("/tts")
def tts(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Single-pass text-to-speech: returns the whole input as one WAV stream."""
    selected = get_voice(actor.value)
    wav_bytes = io.BytesIO()
    with wave.open(wav_bytes, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)  # 16-bit PCM
        writer.setframerate(selected.config.sample_rate)
        selected.synthesize_wav(text, writer, syn_config=syn_config)
    wav_bytes.seek(0)
    return StreamingResponse(wav_bytes, media_type="audio/wav")
@app.post("/speech")
def tts_chunked(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Chunked text-to-speech for longer inputs; returns one combined WAV."""
    stream = synthesize_chunked_tts(text, get_voice(actor.value), syn_config)
    return StreamingResponse(stream, media_type="audio/wav")
@app.post("/stt")
async def speech_to_text(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to English text.

    The upload is spooled to a temporary WAV file for faster-whisper, then
    removed. Returns the joined transcript plus detected language and
    audio duration.
    """
    audio_bytes = await file.read()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        temp_path = f.name
    try:
        segments, info = stt_model.transcribe(
            temp_path,
            beam_size=1,
            language="en",
            vad_filter=True,
            vad_parameters={"min_silence_duration_ms": 500},
        )
        # `segments` is a lazy generator — consume it before deleting the
        # temp file so transcription cannot outlive its input.
        text = " ".join(seg.text for seg in segments)
    finally:
        # Previously the temp file leaked whenever transcribe() raised.
        os.unlink(temp_path)
    return {
        "text": text,
        "language": info.language,
        "duration": info.duration,
    }
@app.post("/convert")
async def stt_to_tts(file: UploadFile = File(...)):
    """Round-trip: transcribe uploaded audio, then re-synthesize it.

    Uses the default Alba voice for synthesis and streams the result back
    as a mono 16-bit WAV.
    """
    audio_bytes = await file.read()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        temp_path = f.name
    try:
        segments, _ = stt_model.transcribe(temp_path)
        # Consume the lazy generator before the temp file is removed.
        text = " ".join(seg.text for seg in segments)
    finally:
        # Previously the temp file leaked whenever transcribe() raised.
        os.unlink(temp_path)
    voice = get_voice(VoiceActor.alba.value)
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 16-bit PCM
        wav.setframerate(voice.config.sample_rate)
        voice.synthesize_wav(text, wav, syn_config=syn_config)
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="audio/wav")