| import os |
| import uvicorn |
| import sys |
| import secrets |
| import json |
| import logging |
| from contextlib import asynccontextmanager |
| from typing import Optional, Dict |
|
|
| from fastapi import FastAPI, HTTPException, Security, status, Depends |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from pydantic import BaseModel |
|
|
| |
| import supertonic_model |
| import kokoro_model |
|
|
| |
| |
| |
# Root logging: timestamped INFO-level lines to the process stream handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
|
|
| |
| |
| |
|
|
| |
# Maps a backend-type string (the values in the MODELS env JSON) to the
# engine class implementing it. Both backends expose a StreamingEngine
# with a stream_generator(...) method (see text_to_speech).
MODEL_FACTORIES = {
    "supertonic": supertonic_model.StreamingEngine,
    "kokoro": kokoro_model.StreamingEngine
}


# model_id -> engine instance; populated at startup by lifespan().
engines: Dict[str, object] = {}
|
|
| |
| |
| |
# auto_error=False so requests without an Authorization header still reach
# verify_api_key, which decides whether a key is required. With the default
# (auto_error=True) FastAPI rejected header-less requests with 403 even when
# API_KEY was unset — contradicting the "open to the public" startup log.
security = HTTPBearer(auto_error=False)


async def verify_api_key(
    credentials: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> bool:
    """Validate the client's Bearer token against the API_KEY env var.

    Returns True when access is allowed. Raises HTTP 401 when API_KEY is
    configured and the client sent no token or a non-matching one.
    """
    server_key = os.getenv("API_KEY")

    # No server-side key configured -> the API runs in open mode.
    if not server_key:
        return True

    # Key required: reject a missing token and a mismatching token alike.
    # compare_digest keeps the comparison constant-time (no timing oracle).
    if credentials is None or not secrets.compare_digest(
        server_key, credentials.credentials
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API Key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True
|
|
| |
| |
| |
class SpeechRequest(BaseModel):
    """Request body for POST /v1/audio/speech (OpenAI-compatible subset)."""
    model: Optional[str] = "tts-1"  # must match a model_id loaded into `engines`
    input: str  # text to synthesize
    voice: str = "alloy"
    format: Optional[str] = "mp3"  # "wav" or "mp3"; anything else falls back to wav
    speed: Optional[float] = 1.0
|
|
| |
| |
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: build the TTS engines from the MODELS env var.

    MODELS must be a JSON object mapping model_id -> backend type (a key of
    MODEL_FACTORIES). The process exits when the configuration is missing or
    invalid, or when no engine loads successfully.
    """
    if not os.getenv("API_KEY"):
        logger.warning("API_KEY not set. API is open to the public.")
    else:
        logger.info("Secure Mode: API Key protection enabled.")

    models_env = os.getenv("MODELS")
    if not models_env:
        logger.error("MODELS environment variable not set. Exiting.")
        sys.exit(1)

    try:
        models_config = json.loads(models_env)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse MODELS JSON: {e}")
        sys.exit(1)

    logger.info(f"Loading models configuration: {models_config}")

    for model_id, backend_type in models_config.items():
        if backend_type not in MODEL_FACTORIES:
            logger.error(f"Unknown backend type '{backend_type}' for model '{model_id}'")
            continue

        try:
            logger.info(f"Initializing {model_id} -> {backend_type}...")
            engine_class = MODEL_FACTORIES[backend_type]
            engines[model_id] = engine_class(f"{model_id}-->{backend_type}")
        except Exception:
            # logger.exception keeps the full traceback, which the previous
            # logger.error(f"...: {e}") form discarded.
            logger.exception(f"Failed to load {model_id}")

    if not engines:
        logger.error("No engines loaded successfully. Exiting.")
        sys.exit(1)

    yield

    # Shutdown: drop engine references so backends can release resources.
    engines.clear()
|
|
| app = FastAPI(lifespan=lifespan, title="Streaming TTS API") |
|
|
| |
| |
| |
|
|
@app.post("/v1/audio/speech", dependencies=[Depends(verify_api_key)])
async def text_to_speech(request: SpeechRequest):
    """OpenAI-compatible speech endpoint: stream synthesized audio.

    Returns a StreamingResponse of audio chunks, a 404 error body when the
    requested model is not loaded, and 500 on engine failure.
    """
    if not engines:
        raise HTTPException(status_code=500, detail="No TTS engines loaded")

    # Unknown model -> OpenAI-style error envelope listing what is loaded.
    if request.model not in engines:
        valid_models = list(engines.keys())
        return JSONResponse(
            status_code=404,
            content={
                "error": {
                    "message": f"Model '{request.model}' not found. Available: {valid_models}",
                    "type": "invalid_request_error",
                    "code": "model_not_found"
                }
            }
        )

    # Normalize the requested container; unsupported values fall back to wav.
    audio_format = request.format if request.format else "mp3"
    if audio_format not in ["wav", "mp3"]:
        audio_format = "wav"

    logger.info(f"Generating: model={request.model} voice={request.voice} fmt={audio_format} len={len(request.input)}")

    try:
        generator = engines[request.model].stream_generator(
            request.input,
            request.voice,
            request.speed,
            audio_format
        )
        # The IANA-registered MIME type for MP3 is audio/mpeg, not audio/mp3
        # (which the old f"audio/{audio_format}" produced).
        media_type = "audio/mpeg" if audio_format == "mp3" else "audio/wav"
        return StreamingResponse(generator, media_type=media_type)
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    """
    Returns the list of currently loaded models dynamically.
    """
    # One OpenAI-style model record per loaded engine; the engine's `name`
    # attribute (when present) is surfaced as the owner.
    return {
        "object": "list",
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": 1677610602,
                "owned_by": getattr(engine_inst, "name", "system"),
            }
            for model_id, engine_inst in engines.items()
        ],
    }
|
|
| |
| |
| |
if __name__ == "__main__":
    # CLI entry point: bind address and port are overridable from the command line.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    opts = cli.parse_args()

    uvicorn.run(app, host=opts.host, port=opts.port)