"""Jambonz custom speech-to-text WebSocket server.

Receives LINEAR16 PCM audio from jambonz over a WebSocket, transcribes it with a
SpeechBrain Whisper model for Arabic (optionally denoised with facebook/denoiser),
and sends FINAL transcriptions back in the jambonz STT message format. Interim
transcriptions are computed internally (for auto-final detection) but not sent.
"""
import asyncio
import json
import logging
import os  # noqa: F401  kept: part of the module's historical import surface
import re
import shutil  # noqa: F401
import subprocess
import sys
import tempfile  # noqa: F401
import threading
import time
import uuid  # noqa: F401
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import numpy as np
import soundfile as sf  # noqa: F401
import torch
import websockets
from speechbrain.inference.ASR import WhisperASR
from transformers import pipeline  # noqa: F401  not used by the SpeechBrain path

# Set up logging early so every section below can use it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
MODEL_NAME = "speechbrain/asr-whisper-large-v2-commonvoice-ar"
MODEL_SAVEDIR = "pretrained_models/asr-whisper-large-v2-commonvoice-ar"
SERVER_HOST = "0.0.0.0"
SERVER_PORT = 3006
# Auto-final: send the accumulated transcript as FINAL once at least
# MIN_INTERIM_COUNT interim results were produced and no new interim result
# has appeared for SILENCE_TIMEOUT_S seconds.
SILENCE_TIMEOUT_S = 1.5
MIN_INTERIM_COUNT = 1
# When False (current behavior) the raw model text is sent and the Arabic
# number conversion is only computed and logged; set True to send digits.
SEND_CONVERTED_NUMBERS = False

# ---------------------------------------------------------------------------
# Optional dependencies
# ---------------------------------------------------------------------------
# noisereduce: probed only. Nothing below calls `nr`, and `denoiser_available`
# is overwritten by the facebook/denoiser section further down.
try:
    import noisereduce as nr  # noqa: F401
    denoiser_available = True
    print("Denoiser available (using noisereduce)")
except ImportError:
    denoiser_available = False
    print("noisereduce not available - install with: pip install noisereduce")

# pyarabic: converts Arabic number words to integers.
try:
    from pyarabic.number import text2number
    arabic_numbers_available = True
    print("Arabic number conversion available")
except ImportError:
    arabic_numbers_available = False
    print("pyarabic not available - install with: pip install pyarabic")
    print("Arabic numbers will not be converted to digits")


# ---------------------------------------------------------------------------
# Arabic number conversion
# ---------------------------------------------------------------------------
# Spelling variants normalized before calling text2number. Applied in order, so
# later patterns see the output of earlier ones (same as the original dict order).
_NUMBER_SPELLING_FIXES = (
    ("اربعة", "أربعة"),
    ("اربع", "أربع"),
    ("اثنين", "اثنان"),
    ("اتنين", "اثنان"),  # Egyptian variant
    ("ثلاث", "ثلاثة"),
    ("خمس", "خمسة"),
    ("ست", "ستة"),
    ("سبع", "سبعة"),
    ("ثمان", "ثمانية"),
    ("تسع", "تسعة"),
    ("عشر", "عشرة"),
)
# Compiled once at import instead of on every call.
_NUMBER_SPELLING_PATTERNS = tuple(
    (re.compile(rf"\b{re.escape(wrong)}\b"), correct)
    for wrong, correct in _NUMBER_SPELLING_FIXES
)
# Capturing group keeps the whitespace runs so the sentence can be rejoined exactly.
_WHITESPACE_SPLIT = re.compile(r"(\s+)")


def _convert_number_word(word: str) -> str:
    """Return *word* as a digit string if text2number recognizes it, else unchanged."""
    stripped = word.strip()
    if not stripped:  # whitespace run: keep verbatim
        return word
    try:
        num = text2number(stripped)
    except Exception:
        return word
    # text2number yields 0 for non-numbers, so 0 is only accepted for "صفر" (zero).
    if isinstance(num, int) and (num != 0 or stripped == "صفر"):
        return str(num)
    return word


def convert_arabic_numbers_in_sentence(sentence: str) -> str:
    """Replace Arabic number words in *sentence* with digits.

    Whitespace is preserved exactly; each whitespace-separated token is converted
    independently (so compound numbers become several digit groups). Common
    spelling variants are normalized first. If pyarabic is not installed, or any
    error occurs, the (possibly partially normalized) sentence is returned.
    """
    if not arabic_numbers_available or not sentence:
        return sentence
    try:
        for pattern, correct in _NUMBER_SPELLING_PATTERNS:
            sentence = pattern.sub(correct, sentence)
        pieces = _WHITESPACE_SPLIT.split(sentence)
        return "".join(_convert_number_word(piece) for piece in pieces)
    except Exception as e:
        logger.warning(f"Error converting Arabic numbers: {e}")
        return sentence


# ---------------------------------------------------------------------------
# flash-attn probe (the SpeechBrain model below does not use `use_flash_attn`)
# ---------------------------------------------------------------------------
try:
    import flash_attn  # noqa: F401
    use_flash_attn = True
except ImportError:
    print("Flash attention not available, using standard attention")
    use_flash_attn = False
    # Best-effort install into the *running* interpreter's environment; it only
    # takes effect on the next start. List arguments avoid going through a shell.
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", "websockets"], check=False)
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
            check=False,
        )
    except OSError as exc:
        logger.warning(f"Automatic pip install failed: {exc}")

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------------------------------------------------------
# facebook/denoiser (used by denoise_audio)
# ---------------------------------------------------------------------------
denoiser_model = None
try:
    import torchaudio  # noqa: F401  probed together with denoiser
    from denoiser import pretrained

    # Load the DNS64 pretrained model (auto-downloads if not cached).
    denoiser_model = pretrained.dns64().to(device)
    denoiser_model.eval()
    denoiser_available = True
    print("facebook/denoiser loaded successfully")
except ImportError:
    denoiser_available = False
    denoiser_model = None
    print("facebook/denoiser not available - install with: pip install denoiser torchaudio")
except Exception as exc:
    # e.g. weight download failure: denoising is optional, keep serving without it.
    denoiser_available = False
    denoiser_model = None
    logger.warning(f"facebook/denoiser failed to load, continuing without denoising: {exc}")

torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")

# ---------------------------------------------------------------------------
# ASR model
# ---------------------------------------------------------------------------
model = WhisperASR.from_hparams(
    source=MODEL_NAME,
    savedir=MODEL_SAVEDIR,
    run_opts={"device": "cuda"} if torch.cuda.is_available() else {},
)

# Thread pool that runs the blocking denoise + transcription work off the event loop.
executor = ThreadPoolExecutor(max_workers=4)

# Project-level temp folder (created at import; not used elsewhere in this module).
PROJECT_DIR = Path(__file__).parent.resolve()
AUDIO_TMP_DIR = PROJECT_DIR / "temp_audio"
AUDIO_TMP_DIR.mkdir(exist_ok=True)


# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def denoise_audio(audio_data, sample_rate=16000):
    """Denoise a mono float waveform with the facebook/denoiser model.

    ``sample_rate`` is accepted for API compatibility but not used; callers pass
    16 kHz audio. Returns the input unchanged when no model is loaded, the input
    is empty, or denoising fails. The result is always a 1-D float32 array.
    """
    if denoiser_model is None or len(audio_data) == 0:
        return audio_data
    try:
        audio_tensor = torch.tensor(audio_data, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            denoised_tensor = denoiser_model(audio_tensor)[0]  # no sample_rate arg
        # reshape(-1) (not squeeze) keeps a 1-D result even for a single sample.
        return denoised_tensor.reshape(-1).cpu().numpy().astype(np.float32)
    except Exception as e:
        logger.warning(f"Denoiser failed: {e}")
        return audio_data


class JambonzAudioBuffer:
    """Thread-safe accumulator of one session's decoded audio.

    All audio received since construction (or the last ``clear``) is kept in
    ``total_audio``; interim processing is signalled once at least
    ``chunk_duration`` seconds have been received.
    """

    def __init__(self, sample_rate=8000, chunk_duration=1.0):
        self.sample_rate = sample_rate
        self.chunk_duration = chunk_duration
        self.chunk_samples = int(chunk_duration * sample_rate)
        self.lock = threading.Lock()
        self.total_audio = np.array([], dtype=np.float32)
        # Voice Activity Detection (simple energy-based)
        self.silence_threshold = 0.01
        self.min_speech_samples = int(0.3 * sample_rate)  # 300ms minimum speech

    @property
    def buffer(self):
        """Read-only alias for ``total_audio``.

        The former separate buffer received exactly the same appends and clears
        and was never consumed, so it always equalled ``total_audio``; keeping a
        second copy only doubled memory and copy work.
        """
        return self.total_audio

    def add_audio(self, audio_data):
        """Append decoded samples to the session audio."""
        with self.lock:
            self.total_audio = np.concatenate([self.total_audio, audio_data])

    def has_chunk_ready(self):
        """Return True once at least one chunk's worth of audio has been received."""
        with self.lock:
            return len(self.total_audio) >= self.chunk_samples

    def is_speech(self, audio_chunk):
        """Simple energy VAD: mean |amplitude| above ``silence_threshold``."""
        if len(audio_chunk) < self.min_speech_samples:
            return False
        energy = np.mean(np.abs(audio_chunk))
        return energy > self.silence_threshold

    def get_chunk_for_processing(self):
        """Return a non-None signal when enough audio is available, else None.

        Audio is deliberately not consumed: interim transcription always runs on
        the whole accumulated utterance (see ``get_all_audio``).
        """
        with self.lock:
            if len(self.total_audio) < self.chunk_samples:
                return None
            return np.array([1])  # dummy array signalling "chunk ready"

    def get_all_audio(self):
        """Return a copy of all accumulated audio for transcription."""
        with self.lock:
            return self.total_audio.copy()

    def clear(self):
        """Discard all accumulated audio."""
        with self.lock:
            self.total_audio = np.array([], dtype=np.float32)


def linear16_to_audio(audio_bytes, sample_rate=8000):
    """Decode LINEAR16 PCM bytes into float32 samples in [-1.0, 1.0).

    LINEAR16 is little-endian signed 16-bit, so the dtype is pinned to ``<i2``
    rather than the host's native byte order. ``sample_rate`` is unused.
    Returns an empty array if decoding fails (e.g. an odd byte count).
    """
    try:
        audio_array = np.frombuffer(audio_bytes, dtype="<i2")
        return audio_array.astype(np.float32) / 32768.0
    except Exception as e:
        logger.error(f"Error converting LINEAR16 to audio: {e}")
        return np.array([], dtype=np.float32)


def resample_audio(audio_data, source_rate, target_rate):
    """Resample by linear interpolation (no anti-alias filter; meant for 8 kHz -> 16 kHz).

    Returns the input unchanged when the rates match and an empty float32 array
    for empty input (np.interp rejects empty sample points).
    """
    if source_rate == target_rate:
        return audio_data
    if len(audio_data) == 0:
        return np.array([], dtype=np.float32)
    ratio = target_rate / source_rate
    indices = np.arange(0, len(audio_data), 1 / ratio)
    indices = indices[indices < len(audio_data)]
    resampled = np.interp(indices, np.arange(len(audio_data)), audio_data)
    return resampled.astype(np.float32)


def _words_to_text(words):
    """Turn the first hypothesis from ``model.transcribe_batch`` into a string.

    Depending on the model's ``normalized_transcripts`` hyper-parameter,
    SpeechBrain returns either one string per utterance or one list of words
    per utterance; joining a plain string with spaces would space out every
    character, so both shapes are handled.
    """
    if not words:
        return ""
    first = words[0]
    if isinstance(first, str):
        return first
    return " ".join(first)


def transcribe_chunk_direct(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Resample, denoise and transcribe *audio_data* with the SpeechBrain model.

    Returns the stripped transcription, or "" for empty/too-short (<100 ms after
    resampling) input or on any error (errors are logged, never raised).
    """
    try:
        if len(audio_data) == 0:
            return ""
        # Step 1: Resample
        resampled_audio = resample_audio(audio_data, source_sample_rate, target_sample_rate)
        # Step 2: Check minimum length (100ms) before spending time on denoising;
        # the denoiser preserves length, so the outcome is the same.
        min_samples = int(0.1 * target_sample_rate)
        if len(resampled_audio) < min_samples:
            return ""
        # Step 3: Denoise
        resampled_audio = denoise_audio(resampled_audio)
        # Step 4: numpy -> [1, T] tensor; wav_lens 1.0 means full length, no padding
        waveform = torch.tensor(resampled_audio, dtype=torch.float32).unsqueeze(0)
        wav_lens = torch.tensor([1.0])
        # Step 5: Transcribe
        words, _tokens = model.transcribe_batch(waveform, wav_lens)
        transcription = _words_to_text(words)
        logger.info(f"SpeechBrain transcription: '{transcription}'")
        return transcription.strip()
    except Exception as e:
        logger.error(f"Error during SpeechBrain transcription: {e}")
        return ""


# ---------------------------------------------------------------------------
# jambonz protocol handler
# ---------------------------------------------------------------------------
class JambonzSTTHandler:
    """Per-connection state for the jambonz custom STT protocol.

    While running, two background tasks operate: one re-transcribes the whole
    accumulated audio to produce (unsent) interim results, the other promotes
    the latest interim result to FINAL after ``silence_timeout`` seconds without
    a new one. At most one FINAL is sent per ``start``.
    """

    def __init__(self, websocket):
        self.websocket = websocket
        self.audio_buffer = None
        self.config = {}
        self.running = True
        self.transcription_task = None
        self.full_transcript = ""
        self.last_partial = ""
        # Auto-final detection state
        self.interim_count = 0
        self.last_interim_time = None  # time.monotonic() of the last new interim
        self.silence_timeout = SILENCE_TIMEOUT_S
        self.min_interim_count = MIN_INTERIM_COUNT
        self.auto_final_task = None
        self.accumulated_transcript = ""
        self.final_sent = False

    async def _cancel_background_tasks(self):
        """Cancel and await the interim and auto-final tasks, if any."""
        for task in (self.transcription_task, self.auto_final_task):
            if task is None:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self.transcription_task = None
        self.auto_final_task = None

    async def start_processing(self, start_message):
        """Initialize the session from jambonz's ``start`` message and start background tasks."""
        self.config = {
            "language": start_message.get("language", "ar-EG"),
            "format": start_message.get("format", "raw"),
            "encoding": start_message.get("encoding", "LINEAR16"),
            "sample_rate": start_message.get("sampleRateHz", 8000),
            "interim_results": start_message.get("interimResults", True),
            "options": start_message.get("options", {}),
        }
        logger.info(f"STT session started with config: {self.config}")

        # A repeated "start" must not leave the previous session's loops running.
        await self._cancel_background_tasks()

        self.audio_buffer = JambonzAudioBuffer(
            sample_rate=self.config["sample_rate"],
            chunk_duration=1.0,  # Process every 1 second
        )
        # Reset auto-final detection state
        self.interim_count = 0
        self.last_interim_time = None
        self.last_partial = ""
        self.accumulated_transcript = ""
        self.final_sent = False

        self.transcription_task = asyncio.create_task(self._process_audio_chunks())
        self.auto_final_task = asyncio.create_task(self._monitor_for_auto_final())

    async def stop_processing(self):
        """Stop background work and send a FINAL transcription if none was sent yet.

        Prefers the latest interim transcript; if there is none (or sending it
        failed), transcribes all accumulated audio once more.
        """
        self.running = False
        await self._cancel_background_tasks()

        if not self.final_sent and self.accumulated_transcript.strip():
            await self.send_transcription(self.accumulated_transcript, is_final=True)

        if self.audio_buffer is not None and not self.final_sent:
            all_audio = self.audio_buffer.get_all_audio()
            if len(all_audio) > 0:
                loop = asyncio.get_running_loop()
                final_transcription = await loop.run_in_executor(
                    executor, transcribe_chunk_direct, all_audio, self.config["sample_rate"]
                )
                if final_transcription.strip():
                    await self.send_transcription(final_transcription, is_final=True)

        logger.info("STT session ended")

    async def add_audio_data(self, audio_bytes):
        """Decode a binary LINEAR16 frame and append it to the session audio."""
        if self.audio_buffer is not None:
            audio_data = linear16_to_audio(audio_bytes, self.config["sample_rate"])
            self.audio_buffer.add_audio(audio_data)

    async def _transcribe_interim(self, loop):
        """Re-transcribe all audio if the latest chunk has speech; record a changed result."""
        all_audio = self.audio_buffer.get_all_audio()
        recent = all_audio[-self.audio_buffer.chunk_samples:]
        if len(all_audio) == 0 or not self.audio_buffer.is_speech(recent):
            return
        transcription = await loop.run_in_executor(
            executor, transcribe_chunk_direct, all_audio, self.config["sample_rate"]
        )
        if transcription.strip() and transcription != self.last_partial:
            self.last_partial = transcription
            self.accumulated_transcript = transcription
            self.interim_count += 1
            self.last_interim_time = time.monotonic()
            # Interim results are dropped by send_transcription (final-only mode).
            await self.send_transcription(transcription, is_final=False)
            logger.info(f"Interim #{self.interim_count}: '{transcription}'")

    async def _process_audio_chunks(self):
        """Background loop producing interim results (only when interimResults is on)."""
        loop = asyncio.get_running_loop()
        while self.running and self.config.get("interim_results", False):
            try:
                # After a FINAL nothing consumes interim results any more (the
                # monitor and stop_processing both require final_sent to be
                # False), so skip the costly re-transcription.
                if (
                    not self.final_sent
                    and self.audio_buffer is not None
                    and self.audio_buffer.has_chunk_ready()
                    and self.audio_buffer.get_chunk_for_processing() is not None
                ):
                    await self._transcribe_interim(loop)
                # Small delay to prevent excessive processing
                await asyncio.sleep(0.1)
            except Exception as e:
                logger.error(f"Error in chunk processing: {e}")
                await asyncio.sleep(1)

    async def _monitor_for_auto_final(self):
        """Send the latest interim transcript as FINAL after ``silence_timeout`` s without a new one.

        Requires at least ``min_interim_count`` interim results. Counters are
        reset afterwards, but ``final_sent`` stays True, so no further FINAL is
        produced until the next ``start``.
        """
        while self.running:
            try:
                now = time.monotonic()
                if (
                    self.interim_count >= self.min_interim_count
                    and self.last_interim_time is not None
                    and (now - self.last_interim_time) >= self.silence_timeout
                    and not self.final_sent
                    and self.accumulated_transcript.strip()
                ):
                    logger.info(
                        f"Auto-final triggered: {self.interim_count} interim results, "
                        f"{now - self.last_interim_time:.1f}s silence"
                    )
                    await self.send_transcription(self.accumulated_transcript, is_final=True)
                    self.final_sent = True
                    self.interim_count = 0
                    self.last_interim_time = None
                    self.accumulated_transcript = ""
                # Check every 0.5 seconds
                await asyncio.sleep(0.5)
            except Exception as e:
                logger.error(f"Error in auto-final monitoring: {e}")
                await asyncio.sleep(1)

    async def send_transcription(self, text, is_final=False, confidence=0.9):
        """Send a FINAL transcription in jambonz format; interim results are dropped.

        The Arabic number conversion is always computed and logged; the
        converted text is sent only when ``SEND_CONVERTED_NUMBERS`` is True.
        Send errors are logged, not raised; ``final_sent`` is set only on success.
        """
        try:
            if not is_final:
                logger.debug("Skipping interim transcription (not final).")
                return

            converted_text = convert_arabic_numbers_in_sentence(text)
            if converted_text != text:
                logger.info(f"Arabic numbers converted: '{text}' -> '{converted_text}'")
            outgoing_text = converted_text if SEND_CONVERTED_NUMBERS else text

            message = {
                "type": "transcription",
                "is_final": True,
                "alternatives": [
                    {
                        "transcript": outgoing_text,
                        "confidence": confidence,
                    }
                ],
                "language": self.config.get("language", "ar-EG"),
                "channel": 1,
            }
            await self.websocket.send(json.dumps(message))
            logger.info(f"Sent FINAL transcription: '{outgoing_text}'")
            self.final_sent = True
        except Exception as e:
            logger.error(f"Error sending transcription: {e}")

    async def send_error(self, error_message):
        """Send an error message in jambonz format (failures are only logged)."""
        try:
            message = {
                "type": "error",
                "error": error_message,
            }
            await self.websocket.send(json.dumps(message))
            logger.error(f"Sent error: {error_message}")
        except Exception as e:
            logger.error(f"Error sending error message: {e}")


async def handle_jambonz_websocket(websocket):
    """Serve one jambonz connection: JSON control messages plus binary audio frames."""
    client_id = f"jambonz_{id(websocket)}"
    logger.info(f"New jambonz connection: {client_id}")
    handler = JambonzSTTHandler(websocket)

    try:
        async for message in websocket:
            try:
                if isinstance(message, str):
                    # JSON control messages
                    data = json.loads(message)
                    message_type = data.get("type")

                    if message_type == "start":
                        logger.info(f"Received start message: {data}")
                        await handler.start_processing(data)
                    elif message_type == "stop":
                        logger.info("Received stop message")
                        await handler.stop_processing()
                        # Close websocket after final transcription
                        await websocket.close(code=1000, reason="Session completed")
                        break
                    else:
                        logger.warning(f"Unknown message type: {message_type}")
                        await handler.send_error(f"Unknown message type: {message_type}")
                else:
                    # Binary audio data (LINEAR16 PCM)
                    if handler.audio_buffer is None:
                        await handler.send_error("Received audio before start message")
                        continue
                    await handler.add_audio_data(message)
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")
                await handler.send_error(f"Invalid JSON: {str(e)}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                await handler.send_error(f"Processing error: {str(e)}")
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"jambonz connection closed: {client_id}")
    except Exception as e:
        logger.error(f"jambonz WebSocket error: {e}")
        try:
            await handler.send_error(str(e))
        except Exception:
            pass  # best effort: the socket is likely already broken
    finally:
        if handler.running:
            await handler.stop_processing()
        logger.info(f"jambonz connection ended: {client_id}")


def _numbers_mode_description():
    """Describe what happens to Arabic number words, for startup output."""
    if SEND_CONVERTED_NUMBERS:
        return "Converted to digits in FINAL transcriptions only"
    return "Conversion logged only; raw transcription text is sent"


async def main():
    """Start the jambonz STT WebSocket server and serve until it is closed."""
    logger.info(f"Starting Jambonz Custom STT WebSocket server on port {SERVER_PORT}...")
    server = await websockets.serve(
        handle_jambonz_websocket,
        SERVER_HOST,
        SERVER_PORT,
        ping_interval=20,
        ping_timeout=10,
        close_timeout=10,
    )
    logger.info(f"Jambonz Custom STT WebSocket server started on ws://{SERVER_HOST}:{SERVER_PORT}")
    logger.info("Ready to handle jambonz STT requests")
    logger.info("- Expects LINEAR16 PCM audio at 8kHz")
    logger.info("- Interim results computed internally (not sent) for auto-final detection")
    logger.info(f"- Auto-final: {MIN_INTERIM_COUNT}+ interim results + {SILENCE_TIMEOUT_S}s silence")
    logger.info("- Resamples to 16kHz for Whisper processing")
    logger.info(f"- Arabic numbers: {_numbers_mode_description()}")
    await server.wait_closed()


if __name__ == "__main__":
    print("=" * 60)
    print("Jambonz Custom STT Server with Whisper + Arabic Numbers")
    print("=" * 60)
    print(f"Model: {MODEL_NAME}")
    print(f"Device: {device}")
    print(f"WebSocket Port: {SERVER_PORT}")
    print("Protocol: jambonz STT API")
    print("Audio Format: LINEAR16 PCM @ 8kHz")
    print(f"Auto-Final: {MIN_INTERIM_COUNT}+ interim results + {SILENCE_TIMEOUT_S}s silence")
    print(f"Arabic Numbers: {_numbers_mode_description()}")
    print("Interim Results: DISABLED (final transcription only)")
    if arabic_numbers_available:
        print("✓ pyarabic library available for number conversion")
    else:
        print("✗ pyarabic library not available - install with: pip install pyarabic")
    print("=" * 60)

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutting down server...")
    except Exception as e:
        print(f"Server error: {e}")