"""FastAPI service exposing Piper text-to-speech and faster-whisper speech-to-text."""

from datetime import datetime
from enum import Enum
import io
import os
import re
import tempfile
from typing import Literal
import wave

from fastapi import FastAPI, Form, UploadFile, File
from fastapi.responses import StreamingResponse, FileResponse
from faster_whisper import WhisperModel
from piper import PiperVoice, SynthesisConfig
from pydantic import BaseModel

# -------------------- PATHS --------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
VOICE_DIR = os.path.join(BASE_DIR, "actors")
WHISPER_MODEL_PATH = os.path.join(BASE_DIR, "models", "faster-whisper-tiny")

# -------------------- CONFIG --------------------
syn_config = SynthesisConfig(
    volume=1.0,
    length_scale=1.15,
    noise_scale=0.55,
    noise_w_scale=0.7,
    normalize_audio=True,
)

# Sentence boundary: whitespace that follows '.', '!' or '?'.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')


# -------------------- ENUMS --------------------
class VoiceActor(str, Enum):
    """Voice model file names, resolved relative to ``VOICE_DIR``."""

    alba = "en_GB-alba-medium.onnx"
    hfc_female = "en_US-hfc_female-medium.onnx"
    danny = "en_US-danny-low.onnx"
    lessac = "en_US-lessac-high.onnx"
    libritts = "en_US-libritts-high.onnx"
    cori = "en_GB-cori-high.onnx"


class Input(BaseModel):
    """JSON request body for synthesis.

    The allowed ``actor`` values mirror every member of ``VoiceActor``.
    No route in the visible part of this file uses this model.
    """

    text: str
    actor: Literal[
        "en_GB-alba-medium.onnx",
        "en_US-hfc_female-medium.onnx",
        "en_US-danny-low.onnx",
        "en_US-lessac-high.onnx",
        "en_US-libritts-high.onnx",
        "en_GB-cori-high.onnx",
    ] | None = VoiceActor.alba.value


# -------------------- APP --------------------
app = FastAPI(title="Fast TTS + STT API")

# -------------------- MODEL CACHE --------------------
print("🔹 Loading Whisper model...")
stt_model = WhisperModel(
    WHISPER_MODEL_PATH,
    device="cpu",
    compute_type="int8",
    # os.cpu_count() may return None; pass 0 so the argument is always an int.
    cpu_threads=os.cpu_count() or 0,
)
print("🔹 Whisper loaded")

voice_cache: dict[str, PiperVoice] = {}


def get_voice(actor: str) -> PiperVoice:
    """Return the Piper voice for ``actor``, loading it on first use.

    Args:
        actor: Model file name inside ``VOICE_DIR``.

    Returns:
        The loaded voice, memoized in ``voice_cache``.
    """
    if actor not in voice_cache:
        model_path = os.path.join(VOICE_DIR, actor)
        voice_cache[actor] = PiperVoice.load(model_path)
    return voice_cache[actor]


# -------------------- UTILS --------------------
def chunk_text(text: str, max_tokens: int = 150) -> list[str]:
    """Group the sentences of ``text`` into chunks of whitespace-separated words.

    Sentences are accumulated until adding the next one would exceed
    ``max_tokens`` words. A single sentence longer than ``max_tokens`` is
    kept whole in its own chunk.

    Args:
        text: Input text.
        max_tokens: Word budget per chunk.

    Returns:
        Non-empty chunk strings; an empty list when ``text`` has no words.
    """
    sentences = _SENTENCE_BOUNDARY.split(text.strip())
    chunks: list[str] = []
    current: list[str] = []

    for sentence in sentences:
        words = sentence.split()
        if len(current) + len(words) <= max_tokens:
            current.extend(words)
        else:
            # Flush only when something was accumulated; otherwise an
            # oversized first sentence would emit an empty chunk.
            if current:
                chunks.append(" ".join(current))
            current = words

    if current:
        chunks.append(" ".join(current))

    return chunks


def _synthesize_wav(text: str, voice: PiperVoice) -> io.BytesIO:
    """Synthesize ``text`` into an in-memory mono 16-bit WAV, rewound to the start."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)
        wav.setframerate(voice.config.sample_rate)
        voice.synthesize_wav(text, wav, syn_config=syn_config)
    buffer.seek(0)
    return buffer


def synthesize_chunked_tts(text: str, voice, syn_config):
    """Synthesize ``text`` chunk by chunk and concatenate the audio into one WAV.

    Args:
        text: Text to speak; split with ``chunk_text``.
        voice: Voice object providing ``config.sample_rate`` and ``synthesize_wav``.
        syn_config: Synthesis configuration forwarded to ``voice.synthesize_wav``.

    Returns:
        An ``io.BytesIO`` holding the combined WAV, positioned at offset 0.
    """
    chunks = chunk_text(text)
    output = io.BytesIO()
    sample_rate = voice.config.sample_rate

    with wave.open(output, "wb") as out_wav:
        out_wav.setnchannels(1)
        out_wav.setsampwidth(2)
        out_wav.setframerate(sample_rate)

        for chunk in chunks:
            # Render each chunk to its own WAV, then copy only the raw frames
            # so the combined file carries a single header.
            buffer = io.BytesIO()
            with wave.open(buffer, "wb") as temp_wav:
                temp_wav.setnchannels(1)
                temp_wav.setsampwidth(2)
                temp_wav.setframerate(sample_rate)
                voice.synthesize_wav(chunk, temp_wav, syn_config=syn_config)

            buffer.seek(0)
            with wave.open(buffer, "rb") as temp_wav:
                out_wav.writeframes(
                    temp_wav.readframes(temp_wav.getnframes())
                )

    output.seek(0)
    return output


def _transcribe_upload(audio_bytes: bytes, **transcribe_kwargs):
    """Write ``audio_bytes`` to a temp file, transcribe it, and always delete the file.

    Args:
        audio_bytes: Raw uploaded audio.
        **transcribe_kwargs: Forwarded unchanged to ``stt_model.transcribe``.

    Returns:
        A ``(text, info)`` tuple: the space-joined segment texts and the info
        object returned by ``stt_model.transcribe``.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_bytes)
        temp_path = f.name

    try:
        segments, info = stt_model.transcribe(temp_path, **transcribe_kwargs)
        # The faster-whisper docs describe ``segments`` as a lazy generator,
        # so consume it while the temp file still exists.
        text = " ".join(seg.text for seg in segments)
    finally:
        # Remove the temp file even when transcription fails.
        os.unlink(temp_path)

    return text, info


# -------------------- ROUTES --------------------
@app.get("/")
def root():
    """Health check."""
    return {"status": "ok"}


@app.post("/tts")
def tts(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Synthesize ``text`` in a single pass and stream it back as WAV."""
    voice = get_voice(actor.value)
    buffer = _synthesize_wav(text, voice)
    return StreamingResponse(buffer, media_type="audio/wav")


@app.post("/speech")
def tts_chunked(
    text: str = Form(...),
    actor: VoiceActor = Form(VoiceActor.alba),
):
    """Synthesize ``text`` chunk by chunk and stream the combined WAV."""
    voice = get_voice(actor.value)
    audio_buffer = synthesize_chunked_tts(text, voice, syn_config)
    return StreamingResponse(audio_buffer, media_type="audio/wav")


@app.post("/stt")
async def speech_to_text(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Returns:
        A dict with the transcribed ``text`` plus the ``language`` and
        ``duration`` reported by the transcription info object.
    """
    audio_bytes = await file.read()

    text, info = _transcribe_upload(
        audio_bytes,
        beam_size=1,
        language="en",
        vad_filter=True,
        vad_parameters={"min_silence_duration_ms": 500},
    )

    return {
        "text": text,
        "language": info.language,
        "duration": info.duration,
    }


@app.post("/convert")
async def stt_to_tts(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and re-speak the text with the alba voice."""
    audio_bytes = await file.read()

    text, _ = _transcribe_upload(audio_bytes)

    voice = get_voice(VoiceActor.alba.value)
    buffer = _synthesize_wav(text, voice)
    return StreamingResponse(buffer, media_type="audio/wav")