"""Jambonz custom STT WebSocket server backed by Whisper.

The server accepts the jambonz custom-STT protocol (JSON control messages plus
binary LINEAR16 PCM frames), transcribes the audio with a Whisper model and
sends back FINAL transcripts in which Arabic number words are replaced by
digits.  Interim transcripts are computed (they drive auto-final detection)
but are intentionally not sent to the client.
"""

import asyncio
import json
import logging
import re
import struct
import subprocess
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
import websockets
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline

# Arabic number conversion imports
try:
    from pyarabic.number import text2number
    arabic_numbers_available = True
    print("Arabic number conversion available")
except ImportError:
    # Keep the name defined so later references never raise NameError.
    text2number = None
    arabic_numbers_available = False
    print("pyarabic not available - install with: pip install pyarabic")
    print("Arabic numbers will not be converted to digits")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Common spelling variants of Arabic number words mapped to the spelling that
# is handed to text2number.
_NUMBER_SPELLING_VARIANTS = {
    "اربعة": "أربعة",
    "اربع": "أربع",
    "اثنين": "اثنان",
    "اتنين": "اثنان",  # Egyptian variant
    "ثلاث": "ثلاثة",
    "خمس": "خمسة",
    "ست": "ستة",
    "سبع": "سبعة",
    "ثمان": "ثمانية",
    "تسع": "تسعة",
    "عشر": "عشرة",
}

# Compiled once; the \b anchors make the match whole-word only.
_NUMBER_VARIANT_RE = re.compile(
    r"\b(?:"
    + "|".join(
        re.escape(variant)
        for variant in sorted(_NUMBER_SPELLING_VARIANTS, key=len, reverse=True)
    )
    + r")\b"
)


def _convert_number_token(token: str) -> str:
    """Return the digit form of ``token`` if it is a number word, else ``token`` unchanged."""
    stripped = token.strip()
    if not stripped:
        # Whitespace separators are passed through untouched.
        return token

    # Normalise spelling only for the lookup; the caller's spelling is kept
    # whenever the token is not converted.
    normalized = _NUMBER_VARIANT_RE.sub(
        lambda match: _NUMBER_SPELLING_VARIANTS[match.group(0)], stripped
    )
    try:
        num = text2number(normalized)
    except Exception:
        return token

    # A result of 0 is only trusted when the word really is "zero"; for any
    # other word it is treated as "not a number".
    if isinstance(num, int) and (num != 0 or normalized == "صفر"):
        return str(num)
    return token


def convert_arabic_numbers_in_sentence(sentence: str) -> str:
    """Replace Arabic number words in a sentence with digits.

    Each whitespace-separated word is converted on its own; whitespace and all
    words that are not recognised as numbers are preserved exactly as given.
    Common spelling variants are normalised before lookup and zero ("صفر") is
    handled explicitly.

    Args:
        sentence: Text to convert.

    Returns:
        The converted sentence, or ``sentence`` unchanged when it is empty,
        when pyarabic is not installed, or when conversion fails.
    """
    if not sentence or not arabic_numbers_available:
        return sentence
    try:
        # The capturing group keeps the whitespace runs in the result list so
        # that the sentence can be re-joined without altering spacing.
        tokens = re.split(r"(\s+)", sentence)
        return "".join(_convert_number_token(token) for token in tokens)
    except Exception as e:
        logger.warning(f"Error converting Arabic numbers: {e}")
        return sentence


# Try to install flash-attn if not available
try:
    import flash_attn
    use_flash_attn = True
except ImportError:
    print("Flash attention not available, using standard attention")
    use_flash_attn = False
    # Best-effort install for the NEXT start; this run keeps standard attention.
    # (websockets is not installed here: its import above has already succeeded.)
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
            check=False,
        )
    except (OSError, subprocess.SubprocessError) as e:
        logger.warning(f"Could not install flash-attn: {e}")

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
MODEL_NAME = "openai/whisper-large-v3-turbo"

print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Model initialization with fallback for attention implementation
_model_kwargs = {
    "torch_dtype": torch_dtype,
    "low_cpu_mem_usage": True,
    "use_safetensors": True,
}
try:
    if use_flash_attn and torch.cuda.is_available():
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_NAME,
            attn_implementation="flash_attention_2",
            **_model_kwargs,
        )
    else:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME, **_model_kwargs)
except Exception as e:
    print(f"Error loading model with flash attention: {e}")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME, **_model_kwargs)

model.to(device)
processor = AutoProcessor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME)

# Thread pool for processing audio
executor = ThreadPoolExecutor(max_workers=4)


class JambonzAudioBuffer:
    """Thread-safe accumulator for the float32 audio of one STT session.

    ``total_audio`` holds every sample received so far.  ``buffer`` only serves
    to answer "has at least one chunk of audio arrived?" and is therefore
    capped at ``chunk_samples`` samples instead of growing without bound.
    """

    def __init__(self, sample_rate=8000, chunk_duration=1.0):
        self.sample_rate = sample_rate
        self.chunk_duration = chunk_duration
        self.chunk_samples = int(chunk_duration * sample_rate)
        self.buffer = np.array([], dtype=np.float32)
        self.lock = threading.Lock()
        self.total_audio = np.array([], dtype=np.float32)

        # Voice Activity Detection (simple energy-based)
        self.silence_threshold = 0.01
        self.min_speech_samples = int(0.3 * sample_rate)  # 300ms minimum speech

    def add_audio(self, audio_data):
        """Append a frame of float32 samples to the session audio."""
        if len(audio_data) == 0:
            return
        with self.lock:
            # Only the most recent chunk_samples samples are kept in `buffer`;
            # that is all has_chunk_ready()/get_chunk_for_processing() need.
            self.buffer = np.concatenate([self.buffer, audio_data])[-self.chunk_samples:]
            self.total_audio = np.concatenate([self.total_audio, audio_data])

    def has_chunk_ready(self):
        """Return True once at least ``chunk_samples`` samples have been received."""
        with self.lock:
            return len(self.buffer) >= self.chunk_samples

    def is_speech(self, audio_chunk):
        """Simple VAD: True when the mean absolute amplitude exceeds the threshold.

        Chunks shorter than ``min_speech_samples`` are never considered speech.
        """
        if len(audio_chunk) < self.min_speech_samples:
            return False
        energy = np.mean(np.abs(audio_chunk))
        return energy > self.silence_threshold

    def get_chunk_for_processing(self):
        """Signal whether enough audio is available to run a transcription.

        Nothing is consumed: interim transcription always works on all
        accumulated audio, so this only returns a dummy one-element array when
        a chunk is ready and ``None`` otherwise.
        """
        with self.lock:
            if len(self.buffer) < self.chunk_samples:
                return None
            return np.array([1])

    def get_all_audio(self):
        """Return a copy of all accumulated audio for transcription."""
        with self.lock:
            return self.total_audio.copy()

    def clear(self):
        """Drop all accumulated audio."""
        with self.lock:
            self.buffer = np.array([], dtype=np.float32)
            self.total_audio = np.array([], dtype=np.float32)


def linear16_to_audio(audio_bytes, sample_rate=8000):
    """Convert LINEAR16 PCM bytes (jambonz format) to a float32 array in [-1, 1).

    Args:
        audio_bytes: Raw little-endian signed 16-bit PCM.
        sample_rate: Unused; kept for interface compatibility.

    Returns:
        float32 numpy array; empty when there is no complete sample or the
        conversion fails.
    """
    try:
        # A frame with an odd byte count would make np.frombuffer raise and the
        # whole frame would be lost; ignore just the dangling byte instead.
        sample_count = len(audio_bytes) // 2
        if sample_count == 0:
            return np.array([], dtype=np.float32)
        pcm = np.frombuffer(audio_bytes, dtype="<i2", count=sample_count)
        # Convert to float32 and normalize
        return pcm.astype(np.float32) / 32768.0
    except Exception as e:
        logger.error(f"Error converting LINEAR16 to audio: {e}")
        return np.array([], dtype=np.float32)


def resample_audio(audio_data, source_rate, target_rate):
    """Resample audio by linear interpolation (used for 8kHz -> 16kHz for Whisper).

    Args:
        audio_data: 1-D array of samples.
        source_rate: Sample rate of ``audio_data`` in Hz.
        target_rate: Desired sample rate in Hz.

    Returns:
        ``audio_data`` itself when the rates are equal, otherwise a float32
        array at the target rate (empty for empty input).
    """
    if source_rate == target_rate:
        return audio_data
    source_len = len(audio_data)
    if source_len == 0:
        # np.interp cannot interpolate over an empty set of sample points.
        return np.array([], dtype=np.float32)

    ratio = target_rate / source_rate
    # Positions of the output samples on the input time axis.  Deriving them
    # from an integer count avoids the drift of a repeated float step.
    indices = np.arange(int(np.ceil(source_len * ratio))) / ratio
    indices = indices[indices < source_len]
    resampled = np.interp(indices, np.arange(source_len), audio_data)

    # Ensure proper float32 dtype for consistency
    return resampled.astype(np.float32)


def transcribe_chunk_direct(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Transcribe audio with the module-level Whisper model via ``generate``.

    Args:
        audio_data: 1-D float32 samples at ``source_sample_rate``.
        source_sample_rate: Sample rate of the input in Hz.
        target_sample_rate: Sample rate fed to the processor in Hz.

    Returns:
        The stripped transcription, or "" when the audio is empty, shorter than
        100 ms after resampling, or transcription fails.
    """
    # NOTE(review): the Whisper feature extractor presumably works on a 30 s
    # window; confirm how audio longer than that is handled for long sessions.
    try:
        if len(audio_data) == 0:
            return ""

        # Resample from 8kHz to 16kHz for Whisper
        resampled_audio = resample_audio(audio_data, source_sample_rate, target_sample_rate)

        # Ensure minimum length for Whisper
        min_samples = int(0.1 * target_sample_rate)  # 100ms minimum
        if len(resampled_audio) < min_samples:
            return ""

        start_time = time.time()

        # Prepare input features with proper dtype
        input_features = processor(
            resampled_audio,
            sampling_rate=target_sample_rate,
            return_tensors="pt",
        ).input_features

        # Ensure correct dtype and device
        input_features = input_features.to(device=device, dtype=torch_dtype)

        # Attention mask over the time axis: (batch, frames).  input_features
        # is (batch, n_mels, frames), so the mel dimension must be skipped.
        attention_mask = torch.ones(
            (input_features.shape[0], input_features.shape[-1]),
            dtype=torch.long,
            device=device,
        )

        # Generate transcription using model directly
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=False,
                temperature=0.0,
                num_beams=1,
                language="ar",
                task="transcribe",
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode the transcription
        transcription = tokenizer.batch_decode(
            predicted_ids,
            skip_special_tokens=True,
        )[0].strip()

        end_time = time.time()
        logger.info(
            f"Direct transcription completed in {end_time - start_time:.2f}s: '{transcription}'"
        )
        return transcription

    except Exception as e:
        # logger.exception keeps the traceback, which plain error() dropped.
        logger.exception(f"Error during direct transcription: {e}")
        return ""


class JambonzSTTHandler:
    """Per-connection state machine for one jambonz STT session.

    Audio frames are accumulated in a ``JambonzAudioBuffer``.  A background
    task repeatedly transcribes all audio received so far; a second task sends
    the latest transcript as FINAL once no new interim text has appeared for
    ``silence_timeout`` seconds.  Only one final is sent per session.
    """

    # Seconds without a changed interim transcript before the final is sent.
    SILENCE_TIMEOUT = 1.5
    # Minimum interim results before considering final.
    MIN_INTERIM_COUNT = 1

    def __init__(self, websocket):
        self.websocket = websocket
        self.audio_buffer = None
        self.config = {}
        self.running = True
        self.transcription_task = None
        self.full_transcript = ""
        self.last_partial = ""

        # Auto-final detection variables
        self.interim_count = 0
        self.last_interim_time = None
        self.silence_timeout = self.SILENCE_TIMEOUT
        self.min_interim_count = self.MIN_INTERIM_COUNT
        self.auto_final_task = None
        self.accumulated_transcript = ""
        self.final_sent = False

        # Sample count of the audio that was last transcribed; lets the
        # interim loop skip re-transcribing identical audio.
        self._last_transcribed_samples = 0

    async def _cancel_background_tasks(self):
        """Cancel the transcription and auto-final tasks and wait for them to end."""
        for attr in ("transcription_task", "auto_final_task"):
            task = getattr(self, attr)
            if task is None:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            setattr(self, attr, None)

    async def start_processing(self, start_message):
        """Initialize the session from the jambonz ``start`` message."""
        # A repeated start must not leave the previous tasks running.
        await self._cancel_background_tasks()

        self.config = {
            "language": start_message.get("language", "ar-EG"),
            "format": start_message.get("format", "raw"),
            "encoding": start_message.get("encoding", "LINEAR16"),
            "sample_rate": start_message.get("sampleRateHz", 8000),
            "interim_results": start_message.get("interimResults", True),
            "options": start_message.get("options", {}),
        }

        logger.info(f"STT session started with config: {self.config}")

        # Initialize audio buffer
        self.audio_buffer = JambonzAudioBuffer(
            sample_rate=self.config["sample_rate"],
            chunk_duration=1.0,  # At least 1 second of audio before transcribing
        )

        # Reset auto-final detection variables
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self._last_transcribed_samples = 0

        # Start background transcription task
        self.transcription_task = asyncio.create_task(self._process_audio_chunks())

        # Start auto-final detection task
        self.auto_final_task = asyncio.create_task(self._monitor_for_auto_final())

    async def stop_processing(self):
        """Stop background work and send a final transcription if none was sent yet."""
        self.running = False

        # Cancel background tasks
        await self._cancel_background_tasks()

        # Send final transcription if not already sent
        if not self.final_sent and self.accumulated_transcript.strip():
            await self.send_transcription(self.accumulated_transcript, is_final=True)

        # Still nothing sent (no usable interim, or the send failed):
        # transcribe everything that was received and send that.
        if self.audio_buffer:
            all_audio = self.audio_buffer.get_all_audio()
            if len(all_audio) > 0 and not self.final_sent:
                loop = asyncio.get_running_loop()
                final_transcription = await loop.run_in_executor(
                    executor,
                    transcribe_chunk_direct,
                    all_audio,
                    self.config["sample_rate"],
                )

                if final_transcription.strip():
                    # Send comprehensive final transcription
                    await self.send_transcription(final_transcription, is_final=True)

        logger.info("STT session ended")

    async def add_audio_data(self, audio_bytes):
        """Decode a binary LINEAR16 frame and add it to the buffer."""
        if self.audio_buffer:
            audio_data = linear16_to_audio(audio_bytes, self.config["sample_rate"])
            self.audio_buffer.add_audio(audio_data)

    async def _transcribe_pending_audio(self, loop):
        """Run one interim transcription pass if there is new speech to process."""
        # Once the final has been sent nothing further is ever sent for this
        # session, so more transcription would only burn compute.
        if self.final_sent:
            return
        if not self.audio_buffer or not self.audio_buffer.has_chunk_ready():
            return
        if self.audio_buffer.get_chunk_for_processing() is None:
            return

        # Get all accumulated audio so far for complete transcription
        all_audio = self.audio_buffer.get_all_audio()
        sample_count = len(all_audio)
        if sample_count == 0 or sample_count == self._last_transcribed_samples:
            # No new audio since the last pass: the result would be identical.
            return

        # Only process if the most recent chunk contains actual speech
        if not self.audio_buffer.is_speech(all_audio[-self.audio_buffer.chunk_samples:]):
            return

        transcription = await loop.run_in_executor(
            executor,
            transcribe_chunk_direct,
            all_audio,
            self.config["sample_rate"],
        )
        self._last_transcribed_samples = sample_count

        if transcription.strip() and transcription != self.last_partial:
            self.last_partial = transcription
            self.accumulated_transcript = transcription  # Update accumulated transcript
            self.interim_count += 1
            self.last_interim_time = time.time()

            # send_transcription drops non-final results; the call is kept so
            # interim sending can be re-enabled in one place.
            await self.send_transcription(transcription, is_final=False)
            logger.info(f"Interim #{self.interim_count}: '{transcription}'")

    async def _process_audio_chunks(self):
        """Background loop producing interim transcripts while the session runs."""
        loop = asyncio.get_running_loop()
        while self.running and self.config.get("interim_results", False):
            try:
                await self._transcribe_pending_audio(loop)

                # Small delay to prevent excessive processing
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in chunk processing: {e}")
                await asyncio.sleep(1)

    async def _monitor_for_auto_final(self):
        """Send the accumulated transcript as FINAL after a period without changes.

        The final is triggered when at least ``min_interim_count`` interim
        results exist and ``silence_timeout`` seconds have passed since the
        last interim result that changed the transcript.
        """
        while self.running:
            try:
                current_time = time.time()

                # Check if we should send auto-final transcription
                if (self.interim_count >= self.min_interim_count and
                        self.last_interim_time is not None and
                        (current_time - self.last_interim_time) >= self.silence_timeout and
                        not self.final_sent and
                        self.accumulated_transcript.strip()):

                    logger.info(f"Auto-final triggered: {self.interim_count} interim results, "
                                f"{current_time - self.last_interim_time:.1f}s silence")

                    # Send the accumulated transcript as final
                    await self.send_transcription(self.accumulated_transcript, is_final=True)
                    self.final_sent = True

                    # Reset counters.
                    # NOTE(review): final_sent is never cleared, so a second
                    # utterance in the same session gets no final - confirm
                    # whether that is intended.
                    self.interim_count = 0
                    self.last_interim_time = None
                    self.accumulated_transcript = ""

                # Check every 0.5 seconds
                await asyncio.sleep(0.5)

            except Exception as e:
                logger.error(f"Error in auto-final monitoring: {e}")
                await asyncio.sleep(1)

    async def send_transcription(self, text, is_final=False, confidence=0.9):
        """Send a FINAL transcription in jambonz format with Arabic numbers as digits.

        Interim results (``is_final=False``) are deliberately not sent; the
        call returns without touching the socket.

        Args:
            text: Transcript text.
            is_final: Whether this is a final result.
            confidence: Confidence value reported to jambonz.
        """
        try:
            if not is_final:
                # Do nothing for interim results
                logger.debug("Skipping interim transcription (not final).")
                return

            # Convert Arabic numbers only for final transcripts
            original_text = text
            converted_text = convert_arabic_numbers_in_sentence(text)

            # Log the conversion if numbers were found and converted
            if original_text != converted_text:
                logger.info(f"Arabic numbers converted: '{original_text}' -> '{converted_text}'")

            message = {
                "type": "transcription",
                "is_final": True,
                "alternatives": [
                    {
                        "transcript": converted_text,
                        "confidence": confidence,
                    }
                ],
                "language": self.config.get("language", "ar-EG"),
                "channel": 1,
            }

            # Send only final messages
            await self.websocket.send(json.dumps(message))
            logger.info(f"Sent FINAL transcription: '{converted_text}'")

            # Only marked after a successful send so callers may retry.
            self.final_sent = True

        except Exception as e:
            logger.error(f"Error sending transcription: {e}")

    async def send_error(self, error_message):
        """Send an error message in jambonz format; failures are logged, not raised."""
        try:
            message = {
                "type": "error",
                "error": error_message,
            }
            await self.websocket.send(json.dumps(message))
            logger.error(f"Sent error: {error_message}")
        except Exception as e:
            logger.error(f"Error sending error message: {e}")


async def handle_jambonz_websocket(websocket):
    """Handle one jambonz WebSocket connection.

    Text frames are JSON control messages (``start`` / ``stop``); binary
    frames are LINEAR16 PCM audio, accepted only after ``start``.
    """
    client_id = f"jambonz_{id(websocket)}"
    logger.info(f"New jambonz connection: {client_id}")

    handler = JambonzSTTHandler(websocket)

    try:
        async for message in websocket:
            try:
                if isinstance(message, str):
                    # Handle JSON control messages
                    data = json.loads(message)
                    if not isinstance(data, dict):
                        await handler.send_error("Invalid message: expected a JSON object")
                        continue
                    message_type = data.get("type")

                    if message_type == "start":
                        logger.info(f"Received start message: {data}")
                        await handler.start_processing(data)

                    elif message_type == "stop":
                        logger.info("Received stop message")
                        await handler.stop_processing()
                        # Close websocket after final transcription
                        await websocket.close(code=1000, reason="Session completed")
                        break

                    else:
                        logger.warning(f"Unknown message type: {message_type}")
                        await handler.send_error(f"Unknown message type: {message_type}")

                else:
                    # Handle binary audio data (LINEAR16 PCM)
                    if handler.audio_buffer is None:
                        await handler.send_error("Received audio before start message")
                        continue

                    await handler.add_audio_data(message)

            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")
                await handler.send_error(f"Invalid JSON: {str(e)}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                await handler.send_error(f"Processing error: {str(e)}")

    except websockets.exceptions.ConnectionClosed:
        logger.info(f"jambonz connection closed: {client_id}")
    except Exception as e:
        logger.error(f"jambonz WebSocket error: {e}")
        try:
            await handler.send_error(str(e))
        except Exception as send_err:
            logger.debug(f"Could not report error to client: {send_err}")
    finally:
        # stop_processing clears `running`, so this only runs when the session
        # was not stopped by an explicit "stop" message.
        if handler.running:
            await handler.stop_processing()
        logger.info(f"jambonz connection ended: {client_id}")


async def main():
    """Start the jambonz STT WebSocket server and serve until it is closed."""
    logger.info("Starting Jambonz Custom STT WebSocket server on port 3006...")

    # Start WebSocket server
    server = await websockets.serve(
        handle_jambonz_websocket,
        "0.0.0.0",
        3006,
        ping_interval=20,
        ping_timeout=10,
        close_timeout=10,
    )

    logger.info("Jambonz Custom STT WebSocket server started on ws://0.0.0.0:3006")
    logger.info("Ready to handle jambonz STT requests")
    logger.info("- Expects LINEAR16 PCM audio at 8kHz")
    logger.info("- Interim results drive auto-final detection (only finals are sent)")
    logger.info(
        f"- Auto-final: {JambonzSTTHandler.MIN_INTERIM_COUNT}+ interim results + "
        f"{JambonzSTTHandler.SILENCE_TIMEOUT}s silence"
    )
    logger.info("- Resamples to 16kHz for Whisper processing")
    logger.info("- Converts Arabic numbers to digits before sending")

    # Wait for the server to close
    await server.wait_closed()


if __name__ == "__main__":
    print("=" * 60)
    print("Jambonz Custom STT Server with Whisper + Arabic Numbers")
    print("=" * 60)
    print(f"Model: {MODEL_NAME}")
    print(f"Device: {device}")
    print("WebSocket Port: 3006")
    print("Protocol: jambonz STT API")
    print("Audio Format: LINEAR16 PCM @ 8kHz")
    print(
        f"Auto-Final: {JambonzSTTHandler.MIN_INTERIM_COUNT}+ interim results + "
        f"{JambonzSTTHandler.SILENCE_TIMEOUT}s silence"
    )
    print("Arabic Numbers: Converted to digits in FINAL transcriptions only")
    print("Interim Results: DISABLED (final transcription only)")
    if arabic_numbers_available:
        print("✓ pyarabic library available for number conversion")
    else:
        print("✗ pyarabic library not available - install with: pip install pyarabic")
    print("=" * 60)

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutting down server...")
    except Exception as e:
        print(f"Server error: {e}")