| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import traceback |
| import whisper |
| import librosa |
| import numpy as np |
| import torch |
| import uvicorn |
| import base64 |
| import io |
| import re |
| import json |
| import asyncio |
| import tempfile |
| import os |
| try: |
| import edge_tts |
| TTS_AVAILABLE = True |
| except ImportError: |
| TTS_AVAILABLE = False |
|
|
| try: |
| from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference |
| from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor |
| import soundfile as sf |
| VIBEVOICE_AVAILABLE = True |
| except ImportError: |
| VIBEVOICE_AVAILABLE = False |
|
|
# Speech recognition model (Whisper checkpoint loaded from a local path).
asr_model = whisper.load_model("models/wpt/wpt.pt")

# Chat LLM: tokenizer plus causal LM in bfloat16, placed on the GPU.
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
# Inference only: switch off dropout and other training-mode behaviour.
lm.eval()
|
|
| |
vibevoice_model = None
vibevoice_processor = None
vibevoice_voice_sample = None


def _load_vibevoice_voice_sample(path):
    """Load a voice-prompt audio file as a mono float32 waveform at 24 kHz.

    Raises whatever ``soundfile``/``librosa`` raise for unreadable files; the caller
    decides whether that is fatal.
    """
    wav, sr = sf.read(path)
    if wav.ndim > 1:
        # Multi-channel file: average the channels (axis 1) down to mono.
        wav = np.mean(wav, axis=1)
    if sr != 24000:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=24000)
    return wav.astype(np.float32)


if VIBEVOICE_AVAILABLE:
    try:
        vibevoice_model_path = os.getenv("VIBEVOICE_MODEL_PATH", "models/VibeVoice-1.5B")
        vibevoice_voice_path = os.getenv("VIBEVOICE_VOICE_PATH", None)
        vibevoice_tokenizer_path = os.getenv("VIBEVOICE_TOKENIZER_PATH", "models/Qwen2.5-1.5B")

        # Resolve relative paths against the current working directory once, up front.
        if vibevoice_model_path and not os.path.isabs(vibevoice_model_path):
            vibevoice_model_path = os.path.abspath(vibevoice_model_path)
        if vibevoice_tokenizer_path and not os.path.isabs(vibevoice_tokenizer_path):
            vibevoice_tokenizer_path = os.path.abspath(vibevoice_tokenizer_path)
        if vibevoice_voice_path and not os.path.isabs(vibevoice_voice_path):
            vibevoice_voice_path = os.path.abspath(vibevoice_voice_path)

        # The tokenizer path has a non-empty default, so this search only runs when
        # VIBEVOICE_TOKENIZER_PATH is explicitly set to an empty string.
        if not vibevoice_tokenizer_path:
            local_qwen_paths = [
                "models/Qwen2.5-1.5B",
                "models/Qwen/Qwen2.5-1.5B",
                os.path.join(vibevoice_model_path, "tokenizer"),
            ]
            tokenizer_files = ["tokenizer_config.json", "vocab.json", "merges.txt"]
            for qwen_path in local_qwen_paths:
                if os.path.isdir(qwen_path) and any(
                    os.path.exists(os.path.join(qwen_path, f)) for f in tokenizer_files
                ):
                    vibevoice_tokenizer_path = qwen_path
                    print(f"Found local Qwen tokenizer at {qwen_path}")
                    break

        print(f"Loading VibeVoice processor from {vibevoice_model_path}")

        # Temporarily point preprocessor_config.json at the local tokenizer while the
        # processor loads. NOTE(review): this implies the processor reads
        # "language_model_pretrained_name" from the config in preference to the kwarg
        # passed below; confirm against VibeVoiceProcessor.from_pretrained.
        preprocessor_config_path = os.path.join(vibevoice_model_path, "preprocessor_config.json")
        # Raw text of the config if we rewrote it; written back verbatim afterwards.
        original_config_text = None

        if vibevoice_tokenizer_path and os.path.exists(preprocessor_config_path):
            try:
                with open(preprocessor_config_path, "r") as f:
                    config_text = f.read()
                config = json.loads(config_text)
                if config.get("language_model_pretrained_name", "") != vibevoice_tokenizer_path:
                    config["language_model_pretrained_name"] = vibevoice_tokenizer_path
                    # Record the original before writing, so even a failed/partial write
                    # is undone by the restore below.
                    original_config_text = config_text
                    with open(preprocessor_config_path, "w") as f:
                        json.dump(config, f, indent=2)
                    print(f"Updated preprocessor_config.json to use local tokenizer: {vibevoice_tokenizer_path}")
            except Exception as config_error:
                print(f"Warning: Could not modify preprocessor_config.json: {config_error}")

        processor_kwargs = {}
        if vibevoice_tokenizer_path:
            processor_kwargs["language_model_pretrained_name"] = vibevoice_tokenizer_path
            print(f"Using tokenizer from: {vibevoice_tokenizer_path}")

        try:
            vibevoice_processor = VibeVoiceProcessor.from_pretrained(vibevoice_model_path, **processor_kwargs)
        finally:
            if original_config_text is not None:
                try:
                    with open(preprocessor_config_path, "w") as f:
                        f.write(original_config_text)
                except Exception as restore_error:
                    # Not fatal for this process, but the model directory stays modified.
                    print(f"Warning: Could not restore preprocessor_config.json: {restore_error}")

        print(f"Loading VibeVoice model from {vibevoice_model_path}")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        load_dtype = torch.bfloat16 if device == "cuda" else torch.float32
        attn_impl = "flash_attention_2" if device == "cuda" else "sdpa"

        try:
            vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                vibevoice_model_path,
                torch_dtype=load_dtype,
                device_map=device if device == "cuda" else None,
                attn_implementation=attn_impl,
            )
            if device != "cuda":
                vibevoice_model.to(device)
        except Exception as e:
            if attn_impl != "flash_attention_2":
                raise
            # flash-attn is an optional dependency; retry with PyTorch SDPA attention.
            print(f"Failed to load with flash_attention_2, falling back to sdpa: {e}")
            vibevoice_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                vibevoice_model_path,
                torch_dtype=load_dtype,
                device_map=device if device in ("cuda", "cpu") else None,
                attn_implementation="sdpa",
            )
            if device not in ("cuda", "cpu"):
                vibevoice_model.to(device)

        vibevoice_model.eval()
        vibevoice_model.set_ddpm_inference_steps(num_steps=10)

        if vibevoice_voice_path and os.path.isfile(vibevoice_voice_path):
            print(f"Loading voice sample from {vibevoice_voice_path}")
            try:
                vibevoice_voice_sample = _load_vibevoice_voice_sample(vibevoice_voice_path)
            except Exception as voice_error:
                print(f"Warning: Could not load voice sample from {vibevoice_voice_path}: {voice_error}")
                vibevoice_voice_sample = None
        else:
            default_voice_paths = [
                "/app/spk_001.wav",
                "spk_001.wav",
                "/home/user/VibeVoice/demo/voices/en-Alice_woman.wav",
                "demo/voices/en-Alice_woman.wav",
                "VibeVoice/demo/voices/en-Alice_woman.wav",
            ]
            for voice_path in default_voice_paths:
                if not os.path.exists(voice_path):
                    continue
                print(f"Loading default voice sample from {voice_path}")
                try:
                    vibevoice_voice_sample = _load_vibevoice_voice_sample(voice_path)
                    break
                except Exception as voice_error:
                    # A bad default voice must not disable VibeVoice; try the next one.
                    print(f"Warning: Could not load default voice sample from {voice_path}: {voice_error}")

        if vibevoice_voice_sample is None:
            print("Warning: No voice sample found. VibeVoice will work without voice cloning.")

        print("VibeVoice initialized successfully")
    except Exception as e:
        print(f"Failed to initialize VibeVoice: {e}")
        traceback.print_exc()
        VIBEVOICE_AVAILABLE = False
        vibevoice_model = None
        vibevoice_processor = None
        vibevoice_voice_sample = None
class EvalHandler:
    """Regex heuristics for IFEval-style formatting instructions.

    ``detect_rules`` maps an instruction to checker names; ``apply_rule_fix`` rewrites a
    response so it is more likely to satisfy those checkers. ``chat`` uses
    ``detect_rules`` to add formatting reminders to its system prompt.
    """

    # (pattern key, checker name) pairs, in the order checkers are reported and applied.
    _RULES = (
        ('comma_restriction', 'CommaChecker'),
        ('placeholder_requirement', 'PlaceholderChecker'),
        ('lowercase_requirement', 'LowercaseLettersEnglishChecker'),
        ('capital_frequency', 'CapitalWordFrequencyChecker'),
        ('quotation_requirement', 'QuotationChecker'),
        ('json_format', 'JsonFormat'),
        ('word_count', 'NumberOfWords'),
        ('section_requirement', 'SectionChecker'),
        ('ending_requirement', 'EndChecker'),
        ('forbidden_words', 'ForbiddenWords'),
        ('capital_letters_only', 'CapitalLettersEnglishChecker'),
    )

    def __init__(self):
        self.rule_patterns = {
            'comma_restriction': re.compile(r'no.*comma|without.*comma', re.IGNORECASE),
            'placeholder_requirement': re.compile(r'placeholder.*\[.*\]|square.*bracket', re.IGNORECASE),
            'lowercase_requirement': re.compile(r'lowercase|no.*capital|all.*lowercase', re.IGNORECASE),
            'capital_frequency': re.compile(r'capital.*letter.*less.*than|capital.*word.*frequency', re.IGNORECASE),
            'quotation_requirement': re.compile(r'wrap.*quotation|double.*quote', re.IGNORECASE),
            'json_format': re.compile(r'json.*format|JSON.*output|format.*json', re.IGNORECASE),
            'word_count': re.compile(r'less.*than.*word|word.*limit|maximum.*word', re.IGNORECASE),
            'section_requirement': re.compile(r'section.*start|SECTION.*X', re.IGNORECASE),
            'ending_requirement': re.compile(r'finish.*exact.*phrase|end.*phrase', re.IGNORECASE),
            'forbidden_words': re.compile(r'not.*allowed|forbidden.*word|without.*word', re.IGNORECASE),
            'capital_letters_only': re.compile(r'all.*capital|CAPITAL.*letter', re.IGNORECASE)
        }

    def detect_rules(self, instruction):
        """Return the checker names whose trigger regex matches ``instruction``.

        ``capital_letters_only`` also matches lowercase instructions ("No capital letters
        are allowed") and capital-word frequency caps ("words with all capital letters
        ... less than N times"). Asking for ALL CAPS alongside either contradicts the
        instruction, so that checker is dropped in those cases.
        """
        applicable_rules = [
            checker for key, checker in self._RULES
            if self.rule_patterns[key].search(instruction)
        ]
        conflicting = {'LowercaseLettersEnglishChecker', 'CapitalWordFrequencyChecker'}
        if 'CapitalLettersEnglishChecker' in applicable_rules and conflicting.intersection(applicable_rules):
            applicable_rules.remove('CapitalLettersEnglishChecker')
        return applicable_rules

    def apply_rule_fix(self, response, rules, instruction= ""):
        """Apply the fix-up for each checker in ``rules`` in order; unknown names are ignored."""
        fixers = {
            'CommaChecker': lambda r: self._fix_commas(r, instruction),
            'PlaceholderChecker': lambda r: self._fix_placeholders(r, instruction),
            'LowercaseLettersEnglishChecker': self._fix_lowercase,
            'CapitalWordFrequencyChecker': lambda r: self._fix_capital_frequency(r, instruction),
            'QuotationChecker': self._fix_quotations,
            'JsonFormat': lambda r: self._fix_json_format(r, instruction),
            'NumberOfWords': lambda r: self._fix_word_count(r, instruction),
            'SectionChecker': lambda r: self._fix_sections(r, instruction),
            'EndChecker': lambda r: self._fix_ending(r, instruction),
            'ForbiddenWords': lambda r: self._fix_forbidden_words(r, instruction),
            'CapitalLettersEnglishChecker': lambda r: self._fix_all_capitals(r, instruction),
        }
        for rule in rules:
            fixer = fixers.get(rule)
            if fixer is not None:
                response = fixer(response)
        return response

    def _fix_commas(self, response, instruction):
        """Remove every ASCII comma."""
        return response.replace(',', '')

    def _fix_placeholders(self, response, instruction):
        """Bracket words until ``[...]`` placeholders reach the "at least N" count.

        Words that already contain brackets are skipped so they are not double-wrapped.
        """
        num_match = re.search(r'at least (\d+)', instruction, re.IGNORECASE)
        if not num_match:
            return response
        missing = int(num_match.group(1)) - len(re.findall(r'\[.*?\]', response))
        if missing <= 0:
            return response
        words = response.split()
        for i, word in enumerate(words):
            if missing <= 0:
                break
            if '[' in word or ']' in word:
                continue
            words[i] = f'[{word}]'
            missing -= 1
        return ' '.join(words)

    def _fix_lowercase(self, response):
        """Lowercase the whole response."""
        return response.lower()

    def _fix_capital_frequency(self, response, instruction):
        """Lowercase all-caps words until fewer than N remain ("less than N" is strict)."""
        max_match = re.search(r'less than (\d+)', instruction, re.IGNORECASE)
        if not max_match:
            return response
        allowed = max(int(max_match.group(1)) - 1, 0)
        words = response.split()
        excess = sum(1 for word in words if word.isupper()) - allowed
        if excess <= 0:
            return response
        for i, word in enumerate(words):
            if excess <= 0:
                break
            if word.isupper():
                words[i] = word.lower()
                excess -= 1
        return ' '.join(words)

    def _fix_quotations(self, response):
        """Wrap the response in double quotes unless it already is."""
        stripped = response.strip()
        if len(stripped) >= 2 and stripped.startswith('"') and stripped.endswith('"'):
            return response
        return f'"{response}"'

    def _fix_json_format(self, response, instruction):
        """Embed the response as the "response" field of a JSON object."""
        return json.dumps({"response": response}, indent=2)

    def _fix_word_count(self, response, instruction):
        """Truncate so the response has fewer than N words ("less than N" is strict)."""
        limit_match = re.search(r'less than (\d+)', instruction, re.IGNORECASE)
        if limit_match:
            word_limit = max(int(limit_match.group(1)) - 1, 0)
            words = response.split()
            if len(words) > word_limit:
                return ' '.join(words[:word_limit])
        return response

    def _fix_sections(self, response, instruction):
        """Split the response into N "SECTION i:" blocks, keeping its content.

        Paragraphs are distributed evenly across the sections. When there are fewer
        paragraphs than sections, sentences are used instead. A section that ends up
        empty gets a filler sentence.
        """
        section_match = re.search(r'(\d+) section', instruction, re.IGNORECASE)
        if not section_match:
            return response
        num_sections = int(section_match.group(1))
        if num_sections <= 0:
            return response
        chunks = [p.strip() for p in re.split(r'\n\s*\n', response) if p.strip()]
        joiner = "\n"
        if len(chunks) < num_sections:
            chunks = [s for s in re.split(r'(?<=[.!?])\s+', response.strip()) if s]
            joiner = " "
        sections = []
        for i in range(num_sections):
            start = i * len(chunks) // num_sections
            end = (i + 1) * len(chunks) // num_sections
            body = joiner.join(chunks[start:end]) or "This section provides content here."
            sections.append(f"SECTION {i+1}:")
            sections.append(body)
        return '\n\n'.join(sections)

    def _fix_ending(self, response, instruction):
        """Append the phrase the instruction says the response must finish with.

        A quoted phrase is taken verbatim, including its punctuation. Otherwise the text
        after the first "phrase" up to the next sentence terminator is used. Nothing is
        appended when no phrase can be extracted.
        """
        quoted = re.search(
            r'finish.*?with.*?phrase[:\s]*(["\'])(.+?)\1(?=[.!?,;]?(?:\s|$))',
            instruction,
            re.IGNORECASE,
        )
        if quoted:
            required_ending = quoted.group(2).strip()
        else:
            end_match = re.search(r'finish.*?with.*?phrase[:\s]*([^.!?]*)', instruction, re.IGNORECASE)
            required_ending = end_match.group(1).strip().strip('"\'').strip() if end_match else ""
        if required_ending and not response.rstrip().endswith(required_ending):
            return response.rstrip() + " " + required_ending
        return response

    def _fix_forbidden_words(self, response, instruction):
        """Remove the words listed after "without ... word(s)" as whole words.

        Quoted terms are preferred; otherwise the list is split on commas/"or"/"and".
        An empty list removes nothing. The old greedy regex could capture a stray "s"
        and delete every letter s.
        """
        forbidden_match = re.search(r'without.*?\bwords?\b[:\s]*([^.!?]*)', instruction, re.IGNORECASE)
        if forbidden_match:
            spec = forbidden_match.group(1)
            terms = re.findall(r'["\']([^"\']+)["\']', spec) or re.split(r',|\bor\b|\band\b', spec)
            removed = False
            for term in terms:
                term = term.strip().strip('"\'').strip()
                if not term:
                    continue
                response, count = re.subn(
                    rf'(?<!\w){re.escape(term)}(?!\w)', '', response, flags=re.IGNORECASE
                )
                removed = removed or count > 0
            if removed:
                # Collapse the double spaces left where words were cut out.
                response = re.sub(r'[ \t]{2,}', ' ', response)
        return response.strip()

    def _fix_all_capitals(self, response, instruction):
        """Uppercase the whole response."""
        return response.upper()
|
# Shared instruction-rule detector used by chat(); chat() re-creates it if it is ever None.
EVAL_HANDLER = EvalHandler()
|
|
def chat(system_prompt: str, user_prompt: str, **gen_kwargs) -> str:
    """Run one chat turn (system + user message) through the Llama model.

    Before generating, ``EVAL_HANDLER.detect_rules`` scans ``user_prompt`` for formatting
    constraints, and matching reminders are appended to the system prompt. Detection is
    best-effort: if it fails, the caller's system prompt is used unchanged.

    Args:
        system_prompt: System message.
        user_prompt: User message.
        **gen_kwargs: Extra keyword arguments forwarded to ``lm.generate()``. They
            override the defaults below (e.g. ``max_new_tokens=256``).

    Returns:
        The decoded completion (prompt tokens excluded), stripped of surrounding whitespace.
    """
    global EVAL_HANDLER
    try:
        if EVAL_HANDLER is None:
            EVAL_HANDLER = EvalHandler()
        applicable_rules = EVAL_HANDLER.detect_rules(user_prompt)
        # (checker name, reminder) pairs; reminders are joined in this order.
        rule_reminders = (
            ('CommaChecker', "Do not use any commas in your response."),
            ('LowercaseLettersEnglishChecker', "Respond in all lowercase letters only."),
            ('CapitalLettersEnglishChecker', "Respond in ALL CAPITAL LETTERS."),
            ('QuotationChecker', "Wrap your entire response in double quotation marks."),
            ('JsonFormat', "Format your response as valid JSON."),
            ('SectionChecker', "Organize your response into clearly marked sections."),
        )
        system_prompt_parts = [
            reminder for rule, reminder in rule_reminders if rule in applicable_rules
        ]
        if system_prompt_parts:
            system_prompt = system_prompt + "\n Follow the instructions given CLOSELY: " + " ".join(system_prompt_parts)
    except Exception as e:
        print(f"Warning: instruction-rule detection failed; using the system prompt unchanged: {e}")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True
    )
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    generation_kwargs = dict(
        pad_token_id=tok.eos_token_id,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.2,
        repetition_penalty=1.1,
        top_k=100,
        top_p=0.9,
    )
    generation_kwargs.update(gen_kwargs)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
    # generate() returns prompt + completion; decode only the new tokens.
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()
|
|
def gt(audio: np.ndarray, sr: int):
    """Transcribe an audio clip with the Whisper ``asr_model``.

    Args:
        audio: Waveform; singleton dimensions (e.g. the ``(1, N)`` shape the endpoints
            pass) are squeezed away and it is cast to float32.
        sr: Sample rate of ``audio`` in Hz; resampled to 16 kHz when it differs.

    Returns:
        The transcription text with surrounding whitespace stripped.
    """
    waveform = np.asarray(audio).squeeze().astype(np.float32)
    if sr != 16_000:
        # Resample the squeezed float32 waveform, not the raw ``audio``. Resampling the
        # raw input kept the (1, N) batch dimension and the original dtype, unlike the
        # 16 kHz path.
        waveform = librosa.resample(waveform, orig_sr=sr, target_sr=16_000).astype(np.float32)
    result = asr_model.transcribe(waveform, fp16=False, language=None)
    return result["text"].strip()
|
|
def sample(rr: str) -> str:
    """Generate a raw continuation of ``rr`` without the chat template.

    A blank prompt is replaced by ``"Hello "``. Only the newly generated tokens are
    decoded, with special tokens skipped.
    """
    prompt = rr if rr.strip() != "" else "Hello "
    encoded = tok(prompt, return_tensors="pt").to(lm.device)
    prompt_length = encoded.input_ids.shape[-1]

    with torch.inference_mode():
        generated = lm.generate(
            **encoded,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    continuation = generated[0][prompt_length:]
    return tok.decode(continuation, skip_special_tokens=True)
|
|
def text_to_speech_vibevoice(text: str) -> np.ndarray:
    """Synthesise ``text`` with VibeVoice (synchronous).

    Each non-empty line of ``text`` becomes a ``"Speaker 1: ..."`` script line. When a
    voice sample was loaded at start-up, it is passed to the processor as the voice prompt.

    Args:
        text: Text to convert to speech.

    Returns:
        Mono float32 audio resampled from 24 kHz to 16 kHz, or one second of 16 kHz
        silence for empty text or when the model produced no audio. Returns ``None``
        when VibeVoice is unavailable or generation raised; ``text_to_speech`` uses
        ``None`` to fall back to edge-tts.
    """
    if not VIBEVOICE_AVAILABLE or vibevoice_model is None or vibevoice_processor is None:
        return None

    try:
        if not text or not text.strip():
            return np.zeros(16000, dtype=np.float32)

        # Single-speaker script: one "Speaker 1:" line per non-empty input line.
        formatted_text = "\n".join(
            f"Speaker 1: {line.strip()}"
            for line in text.strip().split("\n")
            if line.strip()
        )

        processor_kwargs = {
            "text": [formatted_text],
            "padding": True,
            "return_tensors": "pt",
            "return_attention_mask": True,
        }
        if vibevoice_voice_sample is not None:
            # One batch item containing one voice sample.
            processor_kwargs["voice_samples"] = [[vibevoice_voice_sample]]

        inputs = vibevoice_processor(**processor_kwargs)

        # Move every tensor input onto the model's device.
        device = next(vibevoice_model.parameters()).device
        for key, value in inputs.items():
            if torch.is_tensor(value):
                inputs[key] = value.to(device)

        with torch.inference_mode():
            outputs = vibevoice_model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=1.3,
                tokenizer=vibevoice_processor.tokenizer,
                generation_config={'do_sample': False},
                verbose=False,
                is_prefill=(vibevoice_voice_sample is not None),
            )

        if not outputs.speech_outputs or outputs.speech_outputs[0] is None:
            return np.zeros(16000, dtype=np.float32)

        audio_tensor = outputs.speech_outputs[0]
        if torch.is_tensor(audio_tensor):
            # NumPy has no bfloat16, so upcast inside torch before converting.
            audio_array = audio_tensor.float().cpu().numpy()
        else:
            audio_array = np.asarray(audio_tensor, dtype=np.float32)

        # squeeze() turns a (1, 1) output into a 0-d array whose len() raises;
        # atleast_1d keeps it a 1-element waveform instead.
        audio_array = np.atleast_1d(audio_array.astype(np.float32).squeeze())
        if len(audio_array) == 0:
            return np.zeros(16000, dtype=np.float32)

        audio_array = librosa.resample(audio_array, orig_sr=24000, target_sr=16000)
        return audio_array.astype(np.float32)

    except Exception as e:
        print(f"VibeVoice generation failed: {e}")
        traceback.print_exc()
        return None
|
|
async def text_to_speech_edge_tts(text: str, voice: str = "en-US-AriaNeural") -> np.ndarray:
    """Convert text to speech using edge-tts (async).

    The streamed MP3 bytes are written to a temporary file, decoded with librosa,
    downmixed to mono and resampled to 16 kHz. The temporary file is always removed.

    Args:
        text: Text to convert to speech.
        voice: edge-tts voice name (default: en-US-AriaNeural).

    Returns:
        Mono float32 audio at 16 kHz, or one second of silence if no audio was streamed.

    Raises:
        RuntimeError: edge-tts is not installed.
    """
    if not TTS_AVAILABLE:
        raise RuntimeError("edge-tts not available")

    communicate = edge_tts.Communicate(text, voice)
    # Collect chunks and join once; repeated bytes += is quadratic.
    audio_chunks = []
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            audio_chunks.append(chunk["data"])
    audio_data = b"".join(audio_chunks)

    if not audio_data:
        return np.zeros(16000, dtype=np.float32)

    tmp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp_file:
            # Record the path before writing so a failed write still gets cleaned up.
            tmp_file_path = tmp_file.name
            tmp_file.write(audio_data)

        audio_array, sample_rate = librosa.load(tmp_file_path, sr=None, mono=True)
        if sample_rate != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
        return audio_array.astype(np.float32)
    finally:
        if tmp_file_path and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except OSError:
                # Best-effort cleanup; a stray temp file must not mask the result.
                pass
|
|
def clean_v2t_response_for_v2v(response_text: str) -> str:
    """Drop the first two lines of a v2t response before using it for t2v.

    The v2t response typically contains:
    - Line 1: the actual response text (often the input prompt repeated)
    - Line 2: an injected scoring line (e.g. "As an answer 5 points with scale from 5 to 10...")
    - Line 3+: the useful response content

    Fallbacks keep something speakable. If nothing follows line 2, everything after
    line 1 is used. If that is also blank, the whole text is used. The 2-line case now
    uses the same fallback chain; previously it returned "" when line 2 was blank.

    Args:
        response_text: Full response text from the v2t endpoint.

    Returns:
        The cleaned text, stripped of surrounding whitespace ("" for empty input).
    """
    if not response_text:
        return ""

    lines = response_text.split("\n")
    if len(lines) < 2:
        return response_text.strip()

    for start in (2, 1):
        candidate = "\n".join(lines[start:]).strip()
        if candidate:
            return candidate
    return response_text.strip()
|
|
|
|
def _strip_llm_preamble(cleaned_text: str) -> str:
    """Drop leading "Here is the cleaned text:"-style lines an LLM puts before its answer.

    Only blank or marker lines before the first real content line are removed. Later
    lines that merely contain a phrase like "here's" are real content and are kept. If
    every line looks like a preamble, the text is returned unchanged.
    """
    markers = ("cleaned text", "here's", "here is", "result:", "output:")
    lines = cleaned_text.split("\n")
    for index, line in enumerate(lines):
        lowered = line.lower().strip()
        if not lowered or any(marker in lowered for marker in markers):
            continue
        return "\n".join(lines[index:]).strip()
    return cleaned_text


def clean_text_for_tts_with_llm(text: str) -> str:
    """Ask the LLM to rewrite ``text`` so it reads naturally when spoken.

    The LLM is instructed to:
    - remove unicode characters, symbols and formatting that don't contribute to speech;
    - preserve important content, such as math equations, converted to spoken form;
    - keep all meaningful words, numbers and essential punctuation.

    NOTE(review): this goes through ``chat``, which also runs instruction-rule detection
    on the user prompt. Text that happens to mention e.g. "json format" can therefore
    add formatting reminders to the cleanup prompt.

    Args:
        text: Text to clean for TTS.

    Returns:
        The cleaned text. Returns "" for blank input. Returns ``text`` unchanged when
        the model is unavailable, the call fails, or the model returns nothing.
    """
    if not text or not text.strip():
        return ""

    if tok is None or lm is None:
        return text

    system_prompt = (
        "You are a text cleaning assistant. Your task is to clean text for text-to-speech (TTS) conversion.\n"
        "\n"
        "IMPORTANT RULES:\n"
        "1. Remove all unicode characters, special symbols, and formatting that don't contribute to speech\n"
        "2. PRESERVE important content:\n"
        "   - Math equations: Convert them to spoken form (e.g., \"x squared plus y equals 5\" instead of \"x² + y = 5\")\n"
        "   - Numbers: Keep all numbers and convert them to natural speech format\n"
        "   - Important punctuation: Keep periods, commas, question marks, exclamation marks for natural speech flow\n"
        "3. Remove markdown formatting, asterisks, underscores, brackets, etc. that are not needed for speech\n"
        "4. Keep all meaningful words, letters, and essential content\n"
        "5. Make the text natural, clear, and easy to read aloud\n"
        "6. Do NOT remove any actual content or meaning from the text\n"
        "7. Convert any special formatting to natural spoken language\n"
        "\n"
        "Return ONLY the cleaned text, nothing else."
    )
    user_prompt = f"Clean this text for text-to-speech:\n\n{text}"

    try:
        cleaned_text = chat(system_prompt, user_prompt).strip()
    except Exception as e:
        # Cleanup is an optional polish step: speak the uncleaned text rather than fail.
        print(f"Warning: LLM TTS cleanup failed, using original text: {e}")
        return text

    lowered = cleaned_text.lower()
    if "cleaned text" in lowered or "here's" in lowered:
        cleaned_text = _strip_llm_preamble(cleaned_text)

    # An empty completion would become silence downstream; keep the input instead.
    return cleaned_text or text
|
|
|
|
def text_to_speech(text: str, voice: str = "en-US-AriaNeural") -> np.ndarray:
    """Synthesise speech, trying VibeVoice first and edge-tts second.

    ``voice`` only applies to the edge-tts fallback. If neither engine produces audio,
    or edge-tts raises, one second of 16 kHz silence is returned instead.

    Returns:
        Mono float32 audio at 16 kHz.
    """
    vibe_audio = text_to_speech_vibevoice(text)
    if vibe_audio is not None:
        return vibe_audio

    if TTS_AVAILABLE:
        try:
            # Sync endpoints run in a worker thread with no event loop, so asyncio.run is usable.
            return asyncio.run(text_to_speech_edge_tts(text, voice))
        except Exception:
            return np.zeros(16000, dtype=np.float32)

    return np.zeros(16000, dtype=np.float32)
|
|
# Static status reported by /api/v1/health. Models are loaded eagerly at import time,
# so the flag is simply True and no error is recorded.
INITIALIZATION_STATUS = dict(model_loaded=True, error=None)
class GenerateRequest(BaseModel):
    """Request body shared by the /api/v1/v2v and /api/v1/v2t endpoints."""

    # Base64-encoded NumPy .npy payload with the input waveform, decoded by b64() with
    # allow_pickle=False. Either 1-D or 2-D; TODO(review): confirm 2-D payloads are
    # (channels, samples), which is the layout the v2v downmix assumes.
    audio_data: str = Field(
        ...,
        description="",
    )
    # Sample rate of audio_data in Hz; the v2v response audio is resampled to this rate.
    sample_rate: int = Field(..., description="")
|
|
class GenerateResponse(BaseModel):
    """Response body of /api/v1/v2v."""

    # Base64-encoded float32 NumPy .npy payload produced by ab64(), at the request's
    # sample_rate.
    audio_data: str = Field(..., description="")
|
|
# FastAPI application serving the health, voice-to-voice and voice-to-text endpoints.
app = FastAPI(title="V1", version="0.1")
# CORS is fully open: any origin, method and header.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] lets any site
# make credentialed cross-origin requests. Starlette reflects the caller's Origin when
# credentials are sent; verify against the installed version. Nothing in this file reads
# cookies or auth headers, so consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def b64(b64: str) -> np.ndarray:
    """Decode a base64 string carrying a NumPy ``.npy`` payload into an array.

    ``allow_pickle=False`` rejects object arrays, because the payload comes straight
    from request bodies.
    """
    payload = io.BytesIO(base64.b64decode(b64))
    return np.load(payload, allow_pickle=False)
def ab64(arr: np.ndarray, sr: int) -> str:
    """Resample 16 kHz audio to ``sr`` Hz and encode it as base64 of a float32 ``.npy`` payload."""
    resampled = librosa.resample(arr, orig_sr=16000, target_sr=sr).astype(np.float32)
    with io.BytesIO() as buffer:
        np.save(buffer, resampled)
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode()
@app.get("/api/v1/health")
def health_check():
    """Liveness probe that echoes the flags from ``INITIALIZATION_STATUS``."""
    return {
        "status": "healthy",
        **{key: INITIALIZATION_STATUS[key] for key in ("model_loaded", "error")},
    }
@app.post("/api/v1/v2v", response_model=GenerateResponse)
def generate_audio(req: GenerateRequest):
    """Voice-to-voice endpoint: answer spoken input with spoken output.

    Pipeline:
        1. Decode the base64 ``.npy`` payload and reduce it to a single (1, N) channel.
        2. Transcribe it (``gt``).
        3. Generate a text reply with the LLM (``chat``).
        4. Drop the reply's first two lines (``clean_v2t_response_for_v2v``), then have
           the LLM rewrite it for speech (``clean_text_for_tts_with_llm``).
        5. Synthesise speech (VibeVoice, else edge-tts) and return it base64-encoded at
           ``req.sample_rate``.

    An empty transcription and most failures produce one second of silence rather than
    an error. HTTP 500 is raised only when no TTS backend is installed, or when even the
    silent fallback cannot be encoded.
    """
    if not VIBEVOICE_AVAILABLE and not TTS_AVAILABLE:
        raise HTTPException(
            status_code=500,
            detail="TTS functionality not available. Please install VibeVoice or edge-tts"
        )

    try:
        audio_np = b64(req.audio_data)

        if audio_np.ndim == 1:
            audio_np = audio_np.reshape(1, -1)
        elif audio_np.ndim == 2 and audio_np.shape[0] > 1:
            # NOTE(review): assumes 2-D input is (channels, samples); a (samples, channels)
            # payload would be averaged over the wrong axis. Confirm with clients.
            audio_np = audio_np.mean(axis=0, keepdims=True)

        user_message = gt(audio_np, req.sample_rate)
        if not user_message:
            silence = np.zeros(16000, dtype=np.float32)
            return GenerateResponse(audio_data=ab64(silence, req.sample_rate))

        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
        system_prompt += "\n\n" + (
            "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
            "Please ensure that your responses are socially unbiased and positive in nature.\n"
            "If a question does not make any sense, or is not factually coherent, explain why instead of answering "
            "something not correct. If you don't know the answer to a question, please don't share false information."
        )

        response_text = chat(system_prompt, user_message)
        cleaned_response_text = clean_v2t_response_for_v2v(response_text)
        cleaned_response_text = clean_text_for_tts_with_llm(cleaned_response_text)

        try:
            audio_output = text_to_speech(cleaned_response_text)
            encoded_audio = ab64(audio_output, req.sample_rate)
        except Exception:
            # TTS is best-effort: record why, then answer with silence instead of a 500.
            traceback.print_exc()
            silence = np.zeros(16000, dtype=np.float32)
            encoded_audio = ab64(silence, req.sample_rate)

        return GenerateResponse(audio_data=encoded_audio)

    except Exception as e:
        traceback.print_exc()
        try:
            silence = np.zeros(16000, dtype=np.float32)
            encoded_audio = ab64(silence, req.sample_rate)
            return GenerateResponse(audio_data=encoded_audio)
        except Exception:
            # Even silence could not be encoded (e.g. an invalid sample_rate).
            raise HTTPException(status_code=500, detail=f"{e}") from e
|
|
@app.post("/api/v1/v2t")
def generate_text(req: GenerateRequest):
    """Voice-to-text endpoint: transcribe the audio and return the LLM's reply.

    Returns:
        ``{"text": <reply>}``.

    Raises:
        HTTPException: 400 when ``audio_data`` is not a valid base64 ``.npy`` payload;
            500 when transcription or generation fails.
    """
    try:
        audio_np = b64(req.audio_data)
    except (ValueError, EOFError) as e:
        # Malformed client input (bad base64, not .npy, pickled data) is a client error.
        raise HTTPException(status_code=400, detail=f"Invalid audio_data: {e}") from e

    if audio_np.ndim == 1:
        audio_np = audio_np.reshape(1, -1)
    elif audio_np.ndim == 2 and audio_np.shape[0] > 1:
        # Downmix multichannel input the same way /api/v1/v2v does.
        # NOTE(review): assumes (channels, samples) layout.
        audio_np = audio_np.mean(axis=0, keepdims=True)

    try:
        text = gt(audio_np, req.sample_rate)
        system_prompt = "You are a helpful assistant who tries to help answer the user's question."
        response_text = chat(system_prompt, user_prompt=text)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"{e}") from e
    return {"text": response_text}
|
|
if __name__ == "__main__":
    # Pass the app object, not the "server:app" import string. With the string, uvicorn
    # imports this file a second time as module "server", repeating every module-level
    # model load (Whisper, Llama, VibeVoice), and it fails outright unless the file is
    # named server.py. An import string is only needed for reload/workers, which are
    # not used here.
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)
|
|