# speech/server.py — added by harsh-dev in commit 12bc4c0 ("Add application file").
import io
import os
import re
import tempfile
import wave
from datetime import datetime
from enum import Enum
from typing import Literal

from fastapi import BackgroundTasks, FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, StreamingResponse
from faster_whisper import WhisperModel
from piper import PiperVoice, SynthesisConfig
from pydantic import BaseModel
# -------------------- CONFIG --------------------
# Directory containing the Piper voice files; get_voice() loads
# "<VOICE_DIR>/<file name>" for each VoiceActor value.
VOICE_DIR = "actors"

# Local faster-whisper model directory, loaded once when this module is imported.
WHISPER_MODEL_PATH = "./models/faster-whisper-tiny"

# Piper synthesis settings shared by every TTS endpoint in this module.
syn_config = SynthesisConfig(
    length_scale=1.15,  # per Piper docs, values > 1.0 slow the speaking rate
    noise_scale=0.55,
    noise_w_scale=0.7,
    normalize_audio=True,
    volume=1.0,
)
# -------------------- ENUMS --------------------
class VoiceActor(str, Enum):
    """Piper voices selectable through the API.

    Every member's value is the ``.onnx`` file name that ``get_voice`` loads
    from ``VOICE_DIR``. Because of the ``str`` mixin, each member is itself a
    ``str`` and compares equal to its file name.
    """

    # NOTE: member order is significant (iteration / schema order) — append only.
    alba = "en_GB-alba-medium.onnx"
    hfc_female = "en_US-hfc_female-medium.onnx"
    danny = "en_US-danny-low.onnx"
    lessac = "en_US-lessac-high.onnx"
    libritts = "en_US-libritts-high.onnx"
    cori = "en_GB-cori-high.onnx"
class Input(BaseModel):
    """JSON request body for ``POST /tts-demo``.

    ``text`` is the text to synthesize. ``actor`` is the voice's ``.onnx``
    file name (one of the ``VoiceActor`` values) or ``null``. It defaults to
    the ``alba`` voice when omitted.
    """

    text: str
    # Must list every VoiceActor value so /tts-demo accepts the same voices
    # as the form-based /tts and /speech endpoints.
    actor: Literal[
        "en_GB-alba-medium.onnx",
        "en_US-hfc_female-medium.onnx",
        "en_US-danny-low.onnx",
        "en_US-lessac-high.onnx",
        "en_US-libritts-high.onnx",
        "en_GB-cori-high.onnx",
    ] | None = VoiceActor.alba.value
# -------------------- APP --------------------
app = FastAPI(title="Fast TTS + STT API")
# -------------------- MODEL CACHE --------------------
# The Whisper model is loaded once at import time and shared by all requests.
print("🔹 Loading Whisper model...")
stt_model = WhisperModel(
    WHISPER_MODEL_PATH,
    device="cpu",
    compute_type="int8",
    # os.cpu_count() returns None when the core count can't be determined;
    # 0 is faster-whisper's own default ("don't override the thread count").
    cpu_threads=os.cpu_count() or 0,
    num_workers=1,
)
print("🔹 Whisper loaded")
# Loaded Piper voices, keyed by the actor string passed to get_voice().
voice_cache: dict[str, PiperVoice] = {}


def get_voice(actor: str) -> PiperVoice:
    """Return the Piper voice for *actor*, loading it on first use.

    The model is read from ``f"{VOICE_DIR}/{actor}"`` and stored in
    ``voice_cache``, so later calls with the same key reuse that object.
    """
    # NOTE(review): no lock here; if two requests miss the cache at the same
    # time, both may load the model and the last one stored wins.
    cached = voice_cache.get(actor)
    if cached is None:
        cached = PiperVoice.load(f"{VOICE_DIR}/{actor}")
        voice_cache[actor] = cached
    return cached
def chunk_text(text: str, max_tokens: int = 150) -> list[str]:
    """Split *text* into chunks of at most *max_tokens* whitespace-separated words.

    The text is first split into sentences after ``.``, ``!`` or ``?`` that are
    followed by whitespace. Sentences are then packed greedily into chunks, so
    sentence boundaries are kept whenever they fit. A single sentence longer
    than *max_tokens* words is cut into consecutive pieces of *max_tokens*
    words. Its remainder starts the next chunk.

    Args:
        text: Input text; leading/trailing whitespace is ignored.
        max_tokens: Maximum number of words per chunk (must be >= 1).

    Returns:
        Non-empty chunks, each a single-space-joined run of words, in input
        order. Empty or whitespace-only text yields an empty list.

    Raises:
        ValueError: If *max_tokens* is less than 1.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")

    chunks: list[str] = []
    current: list[str] = []
    for sentence in re.split(r"(?<=[.!?])\s+", text.strip()):
        words = sentence.split()
        if len(current) + len(words) > max_tokens:
            # Flush the pending chunk. The guard prevents emitting an empty
            # chunk when the very first sentence is already too long.
            if current:
                chunks.append(" ".join(current))
                current = []
            # A sentence that alone exceeds the limit is cut into full-size
            # pieces so every chunk honours max_tokens.
            while len(words) > max_tokens:
                chunks.append(" ".join(words[:max_tokens]))
                words = words[max_tokens:]
        current.extend(words)
    if current:
        chunks.append(" ".join(current))
    return chunks
def synthesize_chunked_tts(text: str, voice, syn_config, max_tokens: int = 150):
    """Synthesize *text* chunk by chunk and concatenate the result into one WAV.

    The text is split with ``chunk_text(text, max_tokens=max_tokens)``. Each
    chunk is rendered by ``voice.synthesize_wav`` into a scratch in-memory WAV.
    The raw frames of every chunk are then appended, in order, to a single
    output WAV (mono, 16-bit samples, ``voice.config.sample_rate`` Hz).

    Args:
        text: Text to speak.
        voice: Loaded Piper voice (e.g. from ``get_voice``).
        syn_config: Piper synthesis settings applied to every chunk.
        max_tokens: Maximum words per chunk handed to ``chunk_text``. Defaults
            to 150, the value previously hard-coded here.

    Returns:
        An ``io.BytesIO`` holding a complete WAV file, rewound to position 0.
        If no chunk contains words, the WAV is valid but has zero frames.
    """
    sample_rate = voice.config.sample_rate
    output = io.BytesIO()
    with wave.open(output, "wb") as out_wav:
        out_wav.setnchannels(1)
        out_wav.setsampwidth(2)
        out_wav.setframerate(sample_rate)
        for chunk in chunk_text(text, max_tokens=max_tokens):
            if not chunk.strip():
                # Nothing to speak; skip a pointless synthesis call.
                continue
            scratch = io.BytesIO()
            # wave only closes files it opened itself, so `scratch` stays
            # usable after this `with` finalizes the chunk's WAV header.
            with wave.open(scratch, "wb") as chunk_wav:
                chunk_wav.setnchannels(1)
                chunk_wav.setsampwidth(2)
                chunk_wav.setframerate(sample_rate)
                voice.synthesize_wav(chunk, chunk_wav, syn_config=syn_config)
            scratch.seek(0)
            with wave.open(scratch, "rb") as chunk_wav:
                out_wav.writeframes(chunk_wav.readframes(chunk_wav.getnframes()))
    output.seek(0)
    return output
# -------------------- ROUTES --------------------
@app.get("/")
def root():
    """Return a constant ``{"status": "ok"}`` payload (usable as a basic health check)."""
    return dict(status="ok")
# -------- TTS (JSON, returns file) --------
@app.post("/tts-demo")
def tts_demo(input: Input):
    """Synthesize ``input.text`` and return it as a downloadable WAV file.

    The audio is written to a temporary ``.wav`` file. It is served with a
    ``tts-<unix timestamp>.wav`` download name and deleted once the response
    has been sent, or immediately if synthesis fails. An explicit
    ``"actor": null`` falls back to the ``alba`` voice.
    """
    # Input.actor is typed `... | None`; without this fallback a null would
    # reach get_voice() and try to load "<VOICE_DIR>/None".
    actor = input.actor or VoiceActor.alba.value
    voice = get_voice(actor)

    fd, temp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        with wave.open(temp_path, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)
            wav.setframerate(voice.config.sample_rate)
            voice.synthesize_wav(input.text, wav, syn_config=syn_config)
    except Exception:
        # Don't leave a half-written file behind when synthesis fails.
        os.unlink(temp_path)
        raise

    # FileResponse reads the file while sending, so it may only be removed
    # afterwards; background tasks run after the response has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.unlink, temp_path)
    return FileResponse(
        temp_path,
        filename=f"tts-{int(datetime.now().timestamp())}.wav",
        media_type="audio/wav",
        background=cleanup,
    )
# -------- TTS (FORM; WAV built fully in memory, then returned via StreamingResponse – no temp file) --------
@app.post("/tts")
def tts(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Synthesize form field ``text`` with the chosen voice in a single pass.

    The whole utterance is rendered into an in-memory WAV (mono, 16-bit,
    the voice's sample rate) before the response body is returned.
    """
    piper_voice = get_voice(actor.value)
    wav_bytes = io.BytesIO()
    with wave.open(wav_bytes, "wb") as writer:
        writer.setframerate(piper_voice.config.sample_rate)
        writer.setsampwidth(2)
        writer.setnchannels(1)
        piper_voice.synthesize_wav(text, writer, syn_config=syn_config)
    wav_bytes.seek(0)
    return StreamingResponse(wav_bytes, media_type="audio/wav")
# -------- STT ONLY --------
@app.post("/stt")
async def speech_to_text(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to English text.

    The upload is spooled to a temporary ``.wav`` path for faster-whisper,
    and that file is removed even if transcription fails. Transcription is
    CPU-bound, so it runs in FastAPI's threadpool rather than blocking the
    event loop (and every other request) while it works.

    Returns:
        ``{"text": ..., "language": ..., "duration": ...}``. ``text`` is the
        stripped, non-empty segment texts joined by single spaces.
        ``language`` and ``duration`` come from faster-whisper's info object.
    """
    audio_bytes = await file.read()

    def _transcribe() -> dict:
        fd, temp_path = tempfile.mkstemp(suffix=".wav")
        try:
            with os.fdopen(fd, "wb") as tmp:
                tmp.write(audio_bytes)
            segments, info = stt_model.transcribe(
                temp_path,
                beam_size=1,
                language="en",
                vad_filter=True,
                vad_parameters={"min_silence_duration_ms": 500},
            )
            # `segments` is a lazy generator (decoding happens while it is
            # consumed), so consume it before the temp file is removed.
            # Segment texts carry a leading space; strip to avoid doubled spaces.
            text = " ".join(
                piece for piece in (seg.text.strip() for seg in segments) if piece
            )
        finally:
            os.unlink(temp_path)
        return {
            "text": text,
            "language": info.language,
            "duration": info.duration,
        }

    return await run_in_threadpool(_transcribe)
# Renamed from a second `tts`, which shadowed the /tts handler's module-level
# name. The explicit operation_id keeps the OpenAPI operationId FastAPI
# generated for the old name ("tts" + "/speech" + "_post").
@app.post("/speech", operation_id="tts_speech_post")
def speech(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Synthesize form field ``text`` in sentence-sized chunks.

    Unlike ``/tts``, the text goes through ``synthesize_chunked_tts``. That
    helper renders chunks separately and concatenates them into one WAV,
    which is returned as ``audio/wav``.
    """
    voice = get_voice(actor.value)
    audio_buffer = synthesize_chunked_tts(
        text=text,
        voice=voice,
        syn_config=syn_config,
    )
    return StreamingResponse(audio_buffer, media_type="audio/wav")
# -------- STT → TTS --------
@app.post("/convert")
async def convert(file: UploadFile = File(...)):
    """Transcribe uploaded speech and re-speak it with the default voice.

    The uploaded audio is transcribed with the English Whisper model (via a
    temporary ``.wav`` that is always removed). The transcript is then
    synthesized with the ``alba`` voice into an in-memory WAV, which is
    returned as ``audio/wav``. Both steps are CPU-bound, so they run in
    FastAPI's threadpool instead of blocking the event loop.
    """
    audio_bytes = await file.read()

    def _transcribe_and_synthesize() -> io.BytesIO:
        fd, temp_path = tempfile.mkstemp(suffix=".wav")
        try:
            with os.fdopen(fd, "wb") as tmp:
                tmp.write(audio_bytes)
            segments, _ = stt_model.transcribe(
                temp_path,
                beam_size=1,
                language="en",
                vad_filter=True,
                vad_parameters={"min_silence_duration_ms": 500},
            )
            # Consume the lazy `segments` generator before deleting the file;
            # strip each segment so the transcript has single spaces only.
            text = " ".join(
                piece for piece in (seg.text.strip() for seg in segments) if piece
            )
        finally:
            os.unlink(temp_path)

        voice = get_voice(VoiceActor.alba.value)
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)
            wav.setframerate(voice.config.sample_rate)
            voice.synthesize_wav(text, wav, syn_config=syn_config)
        buffer.seek(0)
        return buffer

    buffer = await run_in_threadpool(_transcribe_and_synthesize)
    return StreamingResponse(buffer, media_type="audio/wav")