| """ |
| Kokoro-TTS Local Generator |
| ------------------------- |
| A high-performance text-to-speech system with both Gradio UI and REST API support. |
| Provides multiple voice models, audio formats, and cross-platform compatibility. |
| |
| Key Features: |
| - Multiple voice models support (26+ voices) |
| - Real-time generation with progress tracking |
| - WAV, MP3, and AAC output formats |
| - REST API for programmatic access |
| - Network sharing capabilities |
| - Cross-platform compatibility (Windows, macOS, Linux) |
| - Configurable caching and model management |
| """ |
|
|
| import gradio as gr |
| import json |
| import platform |
|
|
| import shutil |
| from pathlib import Path |
| import soundfile as sf |
| from pydub import AudioSegment |
| import torch |
| import numpy as np |
| import time |
| import uuid |
| from typing import Dict, List, Optional, Union, Tuple, Generator |
| import threading |
| import os |
| import sys |
| import time |
| import socket |
| import threading |
| import logging |
| from datetime import datetime |
| from werkzeug.middleware.dispatcher import DispatcherMiddleware |
| from werkzeug.serving import run_simple |
| |
| from models import ( |
| list_available_voices, build_model, |
| generate_speech |
| ) |
|
|
| |
| from flask import Flask, request, jsonify, send_file |
|
|
| |
# Logging: every record is sent both to the console and to "kokoro_tts.log".
logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("kokoro_tts.log"),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("kokoro_tts")

# File-system and audio constants.
CONFIG_FILE = "tts_config.json"       # JSON file read by load_config / written by save_config
DEFAULT_OUTPUT_DIR = "outputs"        # default value of config["output_dir"]
SAMPLE_RATE = 24000                   # sample rate handed to soundfile when writing WAV files

# Pick the torch device once at import time.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
logger.info(f"Using device: {device}")

# Shared TTS model; stays None until initialize_model() builds it.
model = None

# Default settings; main() replaces this dict with the result of load_config().
config = dict(
    output_dir=DEFAULT_OUTPUT_DIR,
    default_voice=None,
    default_format="wav",
    api_enabled=True,
    api_port=5000,
    ui_port=7860,
    share_ui=True,
)
|
|
def load_config() -> Dict:
    """Read settings from CONFIG_FILE, filling in any missing defaults.

    When the file does not exist, the module-level default ``config`` is
    written out via ``save_config`` and returned as-is.

    Returns:
        The settings loaded from disk (with keys absent from the file taken
        from the defaults), or the default ``config`` dict when the file is
        missing or cannot be read/parsed.
    """
    try:
        # First run: persist the defaults so the user has a file to edit.
        if not os.path.exists(CONFIG_FILE):
            save_config(config)
            return config

        with open(CONFIG_FILE, 'r') as handle:
            stored = json.load(handle)

        # Back-fill options the stored file does not mention.
        for key, default_value in config.items():
            if key not in stored:
                stored[key] = default_value
        return stored
    except Exception as err:
        logger.error(f"Error loading config: {err}")
        return config
|
|
def save_config(config_data: Dict) -> None:
    """Write ``config_data`` to CONFIG_FILE as indented JSON.

    The data is serialized to a string *before* the file is opened for
    writing. Opening with mode ``'w'`` truncates the file immediately, so
    serializing afterwards would wipe the existing configuration whenever
    ``config_data`` contains a value JSON cannot encode.

    Args:
        config_data: Settings to persist.

    Failures are logged and swallowed (best effort); nothing is raised.
    """
    try:
        serialized = json.dumps(config_data, indent=4)
        with open(CONFIG_FILE, 'w') as f:
            f.write(serialized)
    except Exception as e:
        logger.error(f"Error saving config: {e}")
|
|
def initialize_model() -> None:
    """Build the shared TTS model on first use; later calls are no-ops.

    Raises:
        Exception: whatever ``build_model`` raised, re-raised after logging.
    """
    global model

    # Already built: nothing to do.
    if model is not None:
        return

    try:
        logger.info("Initializing Kokoro TTS model...")
        model = build_model(None, device)
        logger.info("Model initialization complete")
    except Exception as err:
        logger.error(f"Error initializing model: {err}")
        raise
|
|
def get_available_voices() -> List[str]:
    """Return the voices reported by ``list_available_voices``.

    The model is initialized first. Any failure along the way is logged and
    reported to the caller as an empty list rather than an exception.
    """
    try:
        initialize_model()
        voice_names = list_available_voices()
        if not voice_names:
            logger.warning("No voices found after initialization.")
        logger.info(f"Available voices: {voice_names}")
    except Exception as err:
        logger.error(f"Error getting voices: {err}")
        return []
    return voice_names
|
|
# Maps the user-facing format name to the container name pydub passes to
# ffmpeg. Raw AAC streams are written with ffmpeg's "adts" muxer; "aac" is
# not a valid ffmpeg *output* format, so exporting with format="aac" fails.
_FFMPEG_EXPORT_FORMATS = {
    "mp3": "mp3",
    "aac": "adts",
}


def convert_audio(input_path: str, output_format: str) -> str:
    """Convert a WAV file to ``output_format`` and return the new file's path.

    The converted file is written next to the input, with the extension
    replaced by ``output_format``. This is a best-effort operation: for an
    unsupported format, or when the conversion itself fails, the problem is
    logged and the original ``input_path`` is returned unchanged.

    Args:
        input_path: Path to an existing WAV file.
        output_format: Target format: "wav" (no-op), "mp3" or "aac".

    Returns:
        Path of the converted file, or ``input_path`` when no conversion
        was performed.
    """
    try:
        if output_format == "wav":
            return input_path

        # Reject unknown formats before paying for decoding the WAV file.
        export_format = _FFMPEG_EXPORT_FORMATS.get(output_format)
        if export_format is None:
            logger.warning(f"Unsupported format: {output_format}, defaulting to wav")
            return input_path

        output_path = os.path.splitext(input_path)[0] + f".{output_format}"
        audio = AudioSegment.from_wav(input_path)
        audio.export(output_path, format=export_format, bitrate="192k")

        logger.info(f"Converted audio to {output_format}: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error converting audio: {e}")
        return input_path
|
|
def generate_tts(
    text: str,
    voice_name: str,
    output_format: str = "wav",
    output_path: Optional[str] = None,
    speed: float = 1.0
) -> Optional[str]:
    """
    Generate TTS audio and return the path to the generated file.

    A WAV file is always written first; for any other ``output_format`` it is
    then passed to ``convert_audio``. Because ``convert_audio`` returns the
    WAV path when it cannot convert, the returned path may end in ``.wav``
    even though another format was requested.

    Args:
        text: Text to convert to speech
        voice_name: Name of the voice to use; resolved to
            ``voices/<voice_name>.pt``. If that file is missing and the name
            is not among the available voices, the first available voice is
            used instead.
        output_format: Output audio format (wav, mp3, aac)
        output_path: Optional custom output path. Its extension is replaced
            by ``.wav`` for the intermediate file. When omitted, a unique
            name inside ``config["output_dir"]`` is generated.
        speed: Speech speed multiplier

    Returns:
        Path to the generated audio file, or None if generation failed
        (the error and traceback are logged).
    """
    try:
        initialize_model()

        os.makedirs(config["output_dir"], exist_ok=True)

        if output_path:
            wav_path = os.path.splitext(output_path)[0] + ".wav"
            # A custom path may point outside the configured output
            # directory; make sure its folder exists before writing.
            parent_dir = os.path.dirname(wav_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            base_name = f"tts_{timestamp}_{str(uuid.uuid4())[:8]}"
            wav_path = os.path.join(config["output_dir"], f"{base_name}.wav")

        logger.info(f"Generating speech for text: '{text[:50]}...' using voice: {voice_name}")

        # Resolve the voice file, falling back to the first available voice
        # when the requested one is unknown.
        voice_path = f"voices/{voice_name}.pt"
        if not os.path.exists(voice_path):
            logger.warning(f"Voice file not found: {voice_path}")
            voices = get_available_voices()
            if not voices:
                raise Exception("No voices available")
            if voice_name not in voices:
                logger.warning(f"Using default voice instead of {voice_name}")
                voice_name = voices[0]
                voice_path = f"voices/{voice_name}.pt"

        # The model yields one 3-tuple per text segment (split on blank
        # lines); the first item is the segment text, the last the audio.
        generator = model(text, voice=voice_path, speed=speed, split_pattern=r'\n+')

        segments = []
        for index, (segment_text, _, audio) in enumerate(generator, start=1):
            if audio is None:
                continue
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()
            # Tensor.numpy() only works on CPU tensors that do not track
            # gradients, and the model may be running on CUDA.
            segments.append(audio.detach().cpu())
            logger.debug(f"Generated segment {index}: {segment_text[:30]}...")

        if not segments:
            raise Exception("No audio generated")

        final_audio = torch.cat(segments, dim=0)
        sf.write(wav_path, final_audio.numpy(), SAMPLE_RATE)
        logger.info(f"Saved WAV file to {wav_path}")

        if output_format != "wav":
            return convert_audio(wav_path, output_format)

        return wav_path

    except Exception as e:
        logger.error(f"Error generating speech: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None
|
|
| |
def create_ui_interface():
    """Create and return the Gradio interface.

    The UI offers a voice dropdown, a text box, an output-format selector and
    a speed slider, plus an audio player and a status line for the result.
    If the configured default voice is missing or unknown, the first
    available voice becomes the default and is saved to the config file.

    Returns:
        The assembled ``gr.Blocks`` interface (not yet launched).
    """
    voices = get_available_voices()
    if not voices:
        logger.error("No voices found! Please check the voices directory.")
        voices = []

    # Pick a usable default voice and remember it for next time.
    default_voice = config.get("default_voice")
    if not default_voice or default_voice not in voices:
        default_voice = voices[0] if voices else None
        if default_voice:
            config["default_voice"] = default_voice
            save_config(config)

    with gr.Blocks(title="CB's TTS Generator") as interface:
        gr.Markdown("# **Welcome to CB's TTS Generator**")
        gr.Markdown("There are multiple voices available for you to choose. This TTS is powered by Kokoro.")

        with gr.Row():
            with gr.Column(scale=1):
                voice = gr.Dropdown(
                    choices=voices,
                    value=default_voice,
                    label="Voice"
                )

                text = gr.Textbox(
                    lines=8,
                    placeholder="Enter text to convert to speech...",
                    label="Text Input"
                )

                format_choice = gr.Radio(
                    choices=["wav", "mp3", "aac"],
                    value=config.get("default_format", "wav"),
                    label="Output Format"
                )

                speed = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Speech Speed"
                )

                generate_btn = gr.Button("Generate Speech", variant="primary")

            with gr.Column(scale=1):
                output = gr.Audio(label="Generated Audio")

                status = gr.Textbox(label="Status", interactive=False)

        def generate_wrapper(voice_name, text_input, chosen_format, speed_value):
            """Run generate_tts and return ``(audio_path_or_None, status_message)``."""
            # text_input may be None as well as blank; treat both as empty.
            if not text_input or not text_input.strip():
                return None, "Error: Please enter some text to convert."

            try:
                output_path = generate_tts(
                    text=text_input,
                    voice_name=voice_name,
                    output_format=chosen_format,
                    speed=speed_value
                )

                if output_path:
                    return output_path, f"Success! Generated audio with voice: {voice_name}"
                else:
                    return None, "Error: Failed to generate audio. Check logs for details."
            except Exception as e:
                logger.error(f"UI generation error: {e}")
                return None, f"Error: {str(e)}"

        def example_wrapper(text_input, voice_name, chosen_format, speed_value):
            """Adapter for gr.Examples, whose columns are text-first.

            Gradio passes example values to ``fn`` in ``inputs`` order
            (text, voice, ...), whereas generate_wrapper expects the voice
            first; calling it directly would swap voice and text.
            """
            return generate_wrapper(voice_name, text_input, chosen_format, speed_value)

        generate_btn.click(
            fn=generate_wrapper,
            inputs=[voice, text, format_choice, speed],
            outputs=[output, status]
        )

        if voices:
            gr.Examples(
                [
                    ["May the Force be with you.", default_voice, "wav", 1.0],
                    ["Here's looking at you, kid.", default_voice, "mp3", 1.0],
                    ["I'll be back.", default_voice, "wav", 1.0],
                    ["Houston, we have a problem.", default_voice, "mp3", 1.0]
                ],
                fn=example_wrapper,
                inputs=[text, voice, format_choice, speed],
                outputs=[output, status]
            )

    return interface
|
|
| |
def create_api_server() -> Flask:
    """Create and configure the Flask API server.

    Routes:
        GET  /api/voices  - available voices and the configured default
        POST /api/tts     - JSON body ``{"text", "voice"?, "format"?, "speed"?}``;
                            responds with the generated audio file
        GET  /api/health  - status, whether the model is loaded, voice count
        GET  /api/config  - current configuration
        PUT  /api/config  - update output_dir / default_voice / default_format

    Malformed client input is answered with HTTP 400; unexpected failures
    are logged and answered with HTTP 500.

    Returns:
        The configured Flask application (not yet running).
    """
    import math  # local import: only needed for validating 'speed'

    app = Flask("KokoroTTS-API")

    # Formats convert_audio knows about, and the MIME type for each.
    supported_formats = ("wav", "mp3", "aac")
    mimetypes_by_format = {
        "wav": "audio/wav",
        "mp3": "audio/mpeg",  # registered type for MP3 (not "audio/mp3")
        "aac": "audio/aac",
    }

    @app.route('/api/voices', methods=['GET'])
    def api_voices():
        """Get available voices."""
        try:
            voices = get_available_voices()
            return jsonify({"voices": voices, "default": config.get("default_voice")})
        except Exception as e:
            logger.error(f"API error in voices: {e}")
            return jsonify({"error": str(e)}), 500

    @app.route('/api/tts', methods=['POST'])
    def api_tts():
        """Generate speech from text."""
        try:
            # silent=True yields None for a missing/invalid JSON body instead
            # of raising, so the client gets a 400 below rather than a 500.
            data = request.get_json(silent=True)

            if not isinstance(data, dict) or 'text' not in data:
                return jsonify({"error": "Missing 'text' field"}), 400

            text = data['text']
            if not isinstance(text, str) or not text.strip():
                return jsonify({"error": "'text' must be a non-empty string"}), 400

            voice = data.get('voice', config.get("default_voice"))
            if voice is not None and not isinstance(voice, str):
                return jsonify({"error": "'voice' must be a string"}), 400

            # The format ends up in a file name, so restrict it to known
            # values. Unknown formats fall back to WAV, which is also what
            # convert_audio does with them.
            output_format = str(data.get('format', config.get("default_format", "wav"))).lower()
            if output_format not in supported_formats:
                logger.warning(f"Unsupported format requested: {output_format!r}, using wav")
                output_format = "wav"

            try:
                speed = float(data.get('speed', 1.0))
            except (TypeError, ValueError):
                return jsonify({"error": "'speed' must be a number"}), 400
            if not math.isfinite(speed) or speed <= 0:
                return jsonify({"error": "'speed' must be a positive, finite number"}), 400

            # Unique file name per request.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            request_id = str(uuid.uuid4())[:8]
            filename = f"api_tts_{timestamp}_{request_id}.{output_format}"
            output_path = os.path.join(config["output_dir"], filename)

            generated_path = generate_tts(
                text=text,
                voice_name=voice,
                output_format=output_format,
                output_path=output_path,
                speed=speed
            )

            if not generated_path or not os.path.exists(generated_path):
                logger.error(f"Generated path doesn't exist: {generated_path}")
                return jsonify({"error": "Failed to generate audio file"}), 500

            # Sanity check: reject files too small to hold real audio.
            file_size = os.path.getsize(generated_path)
            if file_size < 100:
                logger.error(f"Generated file is too small ({file_size} bytes)")
                return jsonify({"error": "Generated audio file appears to be empty or corrupted"}), 500

            logger.info(f"Sending audio file: {generated_path} ({file_size} bytes)")

            # Label the response after the file that was actually produced:
            # when conversion fails, generate_tts hands back the WAV file.
            actual_format = os.path.splitext(generated_path)[1].lstrip('.').lower() or output_format

            # An absolute path keeps the lookup independent of how Flask
            # resolves relative paths.
            return send_file(
                os.path.abspath(generated_path),
                as_attachment=True,
                download_name=f"tts_output.{actual_format}",
                mimetype=mimetypes_by_format.get(actual_format, f"audio/{actual_format}")
            )

        except Exception as e:
            logger.error(f"API error in TTS: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return jsonify({"error": str(e)}), 500

    @app.route('/api/health', methods=['GET'])
    def api_health():
        """Health check endpoint."""
        return jsonify({
            "status": "ok",
            "model_loaded": model is not None,
            "voices_count": len(get_available_voices())
        })

    @app.route('/api/config', methods=['GET', 'PUT'])
    def api_config():
        """Get or update configuration."""
        if request.method == 'GET':
            return jsonify(config)

        try:
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return jsonify({"error": "Request body must be a JSON object"}), 400

            # NOTE(security): this endpoint is unauthenticated, and
            # 'output_dir' lets any client choose where files are written.
            # Restrict or protect it before exposing the API on a network.
            for key in ['output_dir', 'default_voice', 'default_format']:
                if key in data:
                    config[key] = data[key]

            save_config(config)
            return jsonify({"status": "success", "config": config})
        except Exception as e:
            logger.error(f"API error updating config: {e}")
            return jsonify({"error": str(e)}), 500

    return app
|
|
| |
def launch_api(host="0.0.0.0", port=None):
    """Start the Flask API on a background daemon thread.

    Args:
        host: Interface to bind to.
        port: Port to listen on; a falsy value means "use config['api_port']".

    Returns:
        The started thread, or None when the API is disabled in the config.
    """
    api_is_enabled = config.get("api_enabled", True)
    if not api_is_enabled:
        logger.info("API server disabled in configuration")
        return

    api_port = port if port else config.get("api_port", 5000)
    logger.info(f"Launching API server on port {api_port}")

    app = create_api_server()

    def _serve():
        # Runs on the worker thread; errors are logged rather than raised
        # because nothing is waiting on this thread.
        try:
            from werkzeug.serving import run_simple
            run_simple(host, api_port, app, threaded=True, use_reloader=False)
        except Exception as err:
            logger.error(f"Error in API server: {err}")
            import traceback
            logger.error(traceback.format_exc())

    worker = threading.Thread(target=_serve, daemon=True)
    worker.start()

    # Brief pause to let the server come up before announcing it.
    time.sleep(1)
    logger.info(f"API server running at http://{host}:{api_port}")
    return worker
|
|
def launch_ui(server_name="0.0.0.0", server_port=None, share=None):
    """Build and launch the Gradio UI without blocking the calling thread.

    Args:
        server_name: Interface to bind to.
        server_port: Port to listen on; a falsy value means
            "use config['ui_port']".
        share: Whether to request a public share link; None means
            "use config['share_ui']".

    Returns:
        True once ``launch`` has returned.
    """
    port = server_port if server_port else config.get("ui_port", 7860)
    if share is None:
        share_ui = config.get("share_ui", True)
    else:
        share_ui = share

    logger.info(f"Launching UI on port {port} (share={share_ui})")
    interface = create_ui_interface()

    # Request queueing is only enabled when HF_SPACE is not set.
    if "HF_SPACE" not in os.environ:
        interface.queue()

    launch_options = dict(
        server_name=server_name,
        server_port=port,
        share=share_ui,
        prevent_thread_lock=True,
    )
    interface.launch(**launch_options)
    logger.info(f"UI server running at http://{server_name}:{port}")
    return True
|
|
|
|
|
|
| |
def _resolve_network_ip() -> str:
    """Return an IP address for this machine's host name, for display only.

    Falls back to the loopback address when the host name cannot be
    resolved, so a broken hosts/DNS setup does not abort startup.
    """
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError as e:  # socket.gaierror / socket.herror are OSError subclasses
        logger.warning(f"Could not resolve network address: {e}")
        return "127.0.0.1"


def main():
    """Main application entry point.

    Loads the configuration, initializes the model (exiting with status 1 on
    failure) and then starts the servers in one of two modes:

    * single-port mode, when ``HF_SPACE`` is set or ``SINGLE_PORT=1``: UI and
      API are combined and served on the UI port in the foreground;
    * otherwise: the API (if enabled) and the UI each start on a background
      thread, and the main thread idles until Ctrl+C.
    """
    global config

    print("\n" + "="*50)
    print("Starting Kokoro-TTS")
    print("="*50)

    config = load_config()
    os.makedirs(config["output_dir"], exist_ok=True)

    try:
        initialize_model()
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        print(f"ERROR: Failed to initialize model: {e}")
        sys.exit(1)

    # Only used to print a convenient URL; never fatal.
    network_ip = _resolve_network_ip()

    if os.environ.get("HF_SPACE") is not None or os.environ.get("SINGLE_PORT") == "1":
        api_app = create_api_server()
        interface = create_ui_interface()

        # NOTE(review): DispatcherMiddleware and run_simple are WSGI tools,
        # while interface.app is presumably Gradio's ASGI application, and
        # the Flask routes already start with "/api" so mounting at "/api"
        # would expose them under "/api/api/...". Verify this mode end to end.
        combined_app = DispatcherMiddleware(interface.app, {
            '/api': api_app
        })

        port = config.get("ui_port", 7860)
        print(f"Combined UI and API running on port: {port}")
        print(f"Localhost: http://localhost:{port}")
        print(f"Network: http://{network_ip}:{port}")

        # Blocks until the server stops.
        run_simple("0.0.0.0", port, combined_app, use_reloader=False, threaded=True)
    else:
        if config.get("api_enabled", True):
            launch_api()
        ui_thread = threading.Thread(target=launch_ui, daemon=True)
        ui_thread.start()

        print(f"UI (localhost): http://localhost:{config.get('ui_port', 7860)}")
        print(f"UI (network): http://{network_ip}:{config.get('ui_port', 7860)}")
        if config.get("api_enabled", True):
            print(f"API (localhost): http://localhost:{config.get('api_port', 5000)}")
            print(f"API (network): http://{network_ip}:{config.get('api_port', 5000)}")

        # Both servers run on daemon threads, which die with the main
        # thread, so keep it alive until the user interrupts.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down servers...")
            print("Press Ctrl+C again to force quit")


if __name__ == "__main__":
    main()