|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import asyncio
|
| import websockets
|
| import json
|
| import threading
|
| import numpy as np
|
| from transformers import pipeline
|
| import subprocess
|
| import logging
|
| import time
|
| from concurrent.futures import ThreadPoolExecutor
|
| import re
|
| import tempfile
|
| import os
|
| import soundfile as sf
|
| from pathlib import Path
|
|
|
# Optional dependency probes: each try/except sets an availability flag so
# the rest of the module can degrade gracefully when a package is missing.
# NOTE(review): this `denoiser_available` flag is overwritten further down
# by the facebook/denoiser probe; only `denoiser_model` is actually
# consulted at denoise time.
try:
    import noisereduce as nr
    denoiser_available = True
    print("Denoiser available (using noisereduce)")
except ImportError:
    denoiser_available = False
    print("noisereduce not available - install with: pip install noisereduce")


# pyarabic supplies text2number(), used to turn Arabic number words into
# integers inside convert_arabic_numbers_in_sentence().
try:
    from pyarabic.number import text2number
    arabic_numbers_available = True
    print("Arabic number conversion available")
except ImportError:
    arabic_numbers_available = False
    print("pyarabic not available - install with: pip install pyarabic")
    print("Arabic numbers will not be converted to digits")


# Module-wide logger (re-bound again later in the file with the same name).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_arabic_numbers_in_sentence(sentence: str) -> str:
    """
    Replace Arabic number words in a sentence with digits, preserving all
    other words, whitespace and punctuation.

    Common spelling variants are first normalized (e.g. hamza-less or short
    forms) so that pyarabic's text2number() recognizes them.  Zero is
    handled explicitly because text2number() returns 0 both for zero and
    for words it does not recognize.

    Returns the sentence unchanged on any unexpected error, and leaves
    tokens untouched when pyarabic is unavailable.
    """
    try:
        # Normalize common variants/short forms to the canonical spellings
        # that text2number() understands.
        replacements = {
            "اربعة": "أربعة",
            "اربع": "أربع",
            "اثنين": "اثنان",
            "اتنين": "اثنان",
            "ثلاث": "ثلاثة",
            "خمس": "خمسة",
            "ست": "ستة",
            "سبع": "سبعة",
            "ثمان": "ثمانية",
            "تسع": "تسعة",
            "عشر": "عشرة",
        }
        for wrong, correct in replacements.items():
            sentence = re.sub(rf"\b{wrong}\b", correct, sentence)

        # Split on whitespace but KEEP the separators so the sentence can be
        # reassembled without altering the original spacing.
        words = re.split(r'(\s+)', sentence)
        converted_words = []

        for word in words:
            stripped = word.strip()
            if not stripped:
                # Pure-whitespace separator: keep verbatim.
                converted_words.append(word)
                continue

            try:
                num = text2number(stripped)

                if isinstance(num, int):
                    # text2number() returns 0 for non-number words, so only
                    # accept 0 when the token is literally the word for zero.
                    if num != 0 or stripped == "صفر":
                        converted_words.append(str(num))
                    else:
                        converted_words.append(word)
                else:
                    converted_words.append(word)

            except Exception:
                # Not a number word (or pyarabic missing) - keep verbatim.
                converted_words.append(word)

        return ''.join(converted_words)

    except Exception as e:
        logger.warning(f"Error converting Arabic numbers: {e}")
        return sentence
|
|
|
|
|
|
|
# Probe for flash-attn; when missing, fall back to standard attention and
# make a best-effort attempt to install the runtime dependencies.
try:
    import flash_attn
    use_flash_attn = True
except ImportError:
    print("Flash attention not available, using standard attention")
    use_flash_attn = False
    # Best-effort install at startup: check=False plus the broad except keep
    # the server alive even when pip fails (no network, read-only env, ...).
    try:
        # Argument lists with shell=False (the default) avoid shell parsing.
        subprocess.run(
            ["pip", "install", "websockets"],
            check=False
        )
        subprocess.run(
            ["pip", "install", "flash-attn", "--no-build-isolation"],
            check=False
        )
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt.
        pass
|
|
|
# Select the compute device once; reused by the denoiser and ASR model below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Optional facebook/denoiser (DNS64) model for speech enhancement.
# NOTE: this overwrites the noisereduce-based `denoiser_available` flag set
# earlier in the file; denoise_audio() only consults `denoiser_model`.
try:
    import torchaudio
    from denoiser import pretrained

    denoiser_model = pretrained.dns64().to(device)
    denoiser_model.eval()  # inference mode (disables dropout etc.)
    denoiser_available = True
    print("facebook/denoiser loaded successfully")
except ImportError:
    denoiser_available = False
    print("facebook/denoiser not available - install with: pip install denoiser torchaudio")
    denoiser_model = None
|
|
|
|
|
# Half precision on GPU, full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Active ASR checkpoint.  (An earlier assignment to a fine-tuned HF model id
# was dead code - it was unconditionally overwritten here - and has been
# removed; this is the value every consumer of MODEL_NAME sees.)
MODEL_NAME = "speechbrain/asr-whisper-large-v2-commonvoice-ar"

print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
|
|
|
# SpeechBrain Whisper ASR wrapper (imported here, next to its use, rather
# than at the top of the file).
from speechbrain.inference.ASR import WhisperASR


# Download/load the pretrained Arabic Whisper model once at module import.
# savedir caches the checkpoint so subsequent startups skip the download;
# run_opts pins inference to CUDA when available.
model = WhisperASR.from_hparams(
    source="speechbrain/asr-whisper-large-v2-commonvoice-ar",
    savedir="pretrained_models/asr-whisper-large-v2-commonvoice-ar",
    run_opts={"device": "cuda"} if torch.cuda.is_available() else {}
)
|
|
|
|
|
def denoise_audio(audio_data, sample_rate=16000):
    """Apply speech denoising with the facebook/denoiser pretrained model.

    Args:
        audio_data: 1-D float32 numpy array of audio samples.
        sample_rate: sample rate of audio_data (callers resample to 16 kHz
            before invoking this; the parameter is currently informational).

    Returns:
        Denoised float32 numpy array, or the input unchanged when the
        denoiser is unavailable, the input is empty, or inference fails.
    """
    if denoiser_model is None or len(audio_data) == 0:
        return audio_data
    try:
        # The model expects a batched tensor of shape (1, samples).
        audio_tensor = torch.tensor(audio_data, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            denoised_tensor = denoiser_model(audio_tensor)[0]
        return denoised_tensor.squeeze().cpu().numpy().astype("float32")
    except Exception as e:
        # Degrade gracefully: transcribing noisy audio beats no audio.
        # (Was a bare print(); routed through the module logger for
        # consistency with the rest of the file.)
        logger.warning(f"[WARN] Denoiser failed: {e}")
        return audio_data
|
|
|
|
|
| executor = ThreadPoolExecutor(max_workers=4)
|
|
|
class JambonzAudioBuffer:
    """Thread-safe accumulator for incoming PCM audio from jambonz.

    Two streams are kept: ``buffer`` decides when enough new audio has
    arrived for an interim pass, while ``total_audio`` holds everything
    received so the whole stream can be re-transcribed at any time.
    """

    def __init__(self, sample_rate=8000, chunk_duration=1.0):
        self.sample_rate = sample_rate
        self.chunk_duration = chunk_duration
        # Samples that make up one processing chunk.
        self.chunk_samples = int(chunk_duration * sample_rate)

        self.buffer = np.zeros(0, dtype=np.float32)
        self.lock = threading.Lock()
        self.total_audio = np.zeros(0, dtype=np.float32)

        # Energy-based VAD parameters.
        self.silence_threshold = 0.01
        self.min_speech_samples = int(0.3 * sample_rate)

    def add_audio(self, audio_data):
        """Append new samples to both the chunk buffer and the full stream."""
        with self.lock:
            self.buffer = np.concatenate([self.buffer, audio_data])
            self.total_audio = np.concatenate([self.total_audio, audio_data])

    def has_chunk_ready(self):
        """Return True once at least one full chunk of audio is buffered."""
        with self.lock:
            return len(self.buffer) >= self.chunk_samples

    def is_speech(self, audio_chunk):
        """Energy-based VAD: mean absolute amplitude above the threshold."""
        if len(audio_chunk) < self.min_speech_samples:
            return False
        return np.mean(np.abs(audio_chunk)) > self.silence_threshold

    def get_chunk_for_processing(self):
        """Return a non-None signal when a chunk is ready for processing.

        The buffer itself is deliberately left intact so interim passes can
        always re-transcribe the entire accumulated stream.
        """
        with self.lock:
            if len(self.buffer) < self.chunk_samples:
                return None
            return np.array([1])

    def get_all_audio(self):
        """Return a copy of every sample received so far."""
        with self.lock:
            return self.total_audio.copy()

    def clear(self):
        """Drop all buffered audio from both streams."""
        with self.lock:
            self.buffer = np.zeros(0, dtype=np.float32)
            self.total_audio = np.zeros(0, dtype=np.float32)
|
|
|
def linear16_to_audio(audio_bytes, sample_rate=8000):
    """Decode LINEAR16 (signed 16-bit PCM) bytes into float32 samples.

    Jambonz streams raw int16 frames; they are normalized into the
    [-1.0, 1.0) range expected by the ASR pipeline.  Returns an empty
    float32 array when decoding fails.
    """
    try:
        samples = np.frombuffer(audio_bytes, dtype=np.int16)
        # Scale the full int16 range down to [-1, 1).
        return samples.astype(np.float32) / 32768.0
    except Exception as e:
        logger.error(f"Error converting LINEAR16 to audio: {e}")
        return np.array([], dtype=np.float32)
|
|
|
def resample_audio(audio_data, source_rate, target_rate):
    """Linearly interpolate audio from source_rate to target_rate.

    Lightweight alternative to a proper polyphase resampler; adequate for
    the 8 kHz -> 16 kHz upsampling needed before Whisper inference.
    Returns the input untouched when the rates already match.
    """
    if source_rate == target_rate:
        return audio_data

    # Fractional sample positions in the source signal, one per output
    # sample, clipped to the valid source index range.
    step = 1 / (target_rate / source_rate)
    positions = np.arange(0, len(audio_data), step)
    positions = positions[positions < len(audio_data)]
    resampled = np.interp(positions, np.arange(len(audio_data)), audio_data)
    return resampled.astype(np.float32)
|
|
|
|
|
# --- Temp-audio workspace setup -------------------------------------------
# NOTE(review): os/tempfile/soundfile/logging/Path were already imported at
# the top of this file; the re-imports below are redundant but harmless.
import os
import tempfile
import soundfile as sf
import logging

# Rebinds the module logger under the same name (effectively a no-op).
logger = logging.getLogger(__name__)
from pathlib import Path
import uuid
import shutil

# Directory next to this script for intermediate audio files.
PROJECT_DIR = Path(__file__).parent.resolve()
AUDIO_TMP_DIR = PROJECT_DIR / "temp_audio"
AUDIO_TMP_DIR.mkdir(exist_ok=True)
|
def transcribe_chunk_direct(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Resample, denoise and transcribe one audio chunk with SpeechBrain.

    Returns the transcription string, or "" for empty/too-short audio and
    on any inference error.
    """
    try:
        if len(audio_data) == 0:
            return ""

        # Whisper expects 16 kHz input; telephony audio arrives at 8 kHz.
        audio = resample_audio(audio_data, source_sample_rate, target_sample_rate)

        # Clean up the signal before inference (no-op if denoiser missing).
        audio = denoise_audio(audio)

        # Skip fragments shorter than 100 ms - they only produce noise.
        if len(audio) < int(0.1 * target_sample_rate):
            return ""

        # Batch of one full-length utterance: shape (1, samples) with a
        # relative length of 1.0.
        waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
        wav_lens = torch.tensor([1.0])

        words, tokens = model.transcribe_batch(waveform, wav_lens)

        transcription = " ".join(words[0]) if words and len(words) > 0 else ""

        logger.info(f"SpeechBrain transcription: '{transcription}'")
        return transcription.strip()

    except Exception as e:
        logger.error(f"Error during SpeechBrain transcription: {e}")
        return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JambonzSTTHandler:
    """Per-connection jambonz STT session handler.

    Driven by handle_jambonz_websocket(): a "start" control message
    configures the session and spawns two background tasks (interim
    transcription and an auto-final silence watchdog); binary frames feed
    the audio buffer; "stop" or disconnect flushes a final transcription.
    """

    def __init__(self, websocket):
        self.websocket = websocket
        self.audio_buffer = None          # created by start_processing()
        self.config = {}
        self.running = True
        self.transcription_task = None
        self.full_transcript = ""
        self.last_partial = ""            # last interim text, suppresses duplicates

        # Auto-final state: once min_interim_count interim results exist,
        # silence_timeout seconds without a new one promotes the
        # accumulated transcript to a final result.
        self.interim_count = 0
        self.last_interim_time = None
        self.silence_timeout = 1.5
        self.min_interim_count = 1
        self.auto_final_task = None
        self.accumulated_transcript = ""
        self.final_sent = False

    async def start_processing(self, start_message):
        """Initialize the session from the jambonz "start" message."""
        self.config = {
            "language": start_message.get("language", "ar-EG"),
            "format": start_message.get("format", "raw"),
            "encoding": start_message.get("encoding", "LINEAR16"),
            "sample_rate": start_message.get("sampleRateHz", 8000),
            "interim_results": start_message.get("interimResults", True),
            "options": start_message.get("options", {})
        }

        logger.info(f"STT session started with config: {self.config}")

        self.audio_buffer = JambonzAudioBuffer(
            sample_rate=self.config["sample_rate"],
            chunk_duration=1.0
        )

        # Reset auto-final tracking for the new session.
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False

        # Background task: periodic interim transcription of buffered audio.
        self.transcription_task = asyncio.create_task(self._process_audio_chunks())

        # Background task: silence watchdog that emits the final result.
        self.auto_final_task = asyncio.create_task(self._monitor_for_auto_final())

    async def stop_processing(self):
        """Stop background tasks and flush a final transcription."""
        self.running = False

        if self.transcription_task:
            self.transcription_task.cancel()
            try:
                await self.transcription_task
            except asyncio.CancelledError:
                pass

        if self.auto_final_task:
            self.auto_final_task.cancel()
            try:
                await self.auto_final_task
            except asyncio.CancelledError:
                pass

        # Prefer the transcript accumulated from interim passes.
        if not self.final_sent and self.accumulated_transcript.strip():
            await self.send_transcription(self.accumulated_transcript, is_final=True)

        # Fallback: transcribe everything received if nothing was sent yet.
        if self.audio_buffer:
            all_audio = self.audio_buffer.get_all_audio()
            if len(all_audio) > 0 and not self.final_sent:
                loop = asyncio.get_event_loop()
                final_transcription = await loop.run_in_executor(
                    executor,
                    transcribe_chunk_direct,
                    all_audio,
                    self.config["sample_rate"]
                )

                if final_transcription.strip():
                    await self.send_transcription(final_transcription, is_final=True)

        logger.info("STT session ended")

    async def add_audio_data(self, audio_bytes):
        """Decode a LINEAR16 frame and append it to the audio buffer."""
        if self.audio_buffer:
            audio_data = linear16_to_audio(audio_bytes, self.config["sample_rate"])
            self.audio_buffer.add_audio(audio_data)

    async def _process_audio_chunks(self):
        """Periodically re-transcribe the full stream for interim results."""
        while self.running and self.config.get("interim_results", False):
            try:
                if self.audio_buffer and self.audio_buffer.has_chunk_ready():
                    chunk_signal = self.audio_buffer.get_chunk_for_processing()
                    if chunk_signal is not None:
                        # Always transcribe the WHOLE stream so interim
                        # results refine rather than fragment the text.
                        all_audio = self.audio_buffer.get_all_audio()

                        # Only bother when the newest chunk contains speech.
                        if len(all_audio) > 0 and self.audio_buffer.is_speech(all_audio[-self.audio_buffer.chunk_samples:]):

                            loop = asyncio.get_event_loop()
                            transcription = await loop.run_in_executor(
                                executor,
                                transcribe_chunk_direct,
                                all_audio,
                                self.config["sample_rate"]
                            )

                            if transcription.strip() and transcription != self.last_partial:
                                self.last_partial = transcription
                                self.accumulated_transcript = transcription
                                self.interim_count += 1
                                self.last_interim_time = time.time()

                                # NOTE: send_transcription() drops interim
                                # results by design; this call is a no-op
                                # apart from its debug log.
                                await self.send_transcription(transcription, is_final=False)

                                logger.info(f"Interim #{self.interim_count}: '{transcription}'")

                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in chunk processing: {e}")
                await asyncio.sleep(1)

    async def _monitor_for_auto_final(self):
        """Emit a final result once >= min_interim_count interim results
        have been followed by silence_timeout seconds without a new one
        (defaults: 1+ interim results, 1.5 s silence)."""
        while self.running:
            try:
                current_time = time.time()

                if (self.interim_count >= self.min_interim_count and
                    self.last_interim_time is not None and
                    (current_time - self.last_interim_time) >= self.silence_timeout and
                    not self.final_sent and
                    self.accumulated_transcript.strip()):

                    logger.info(f"Auto-final triggered: {self.interim_count} interim results, "
                                f"{current_time - self.last_interim_time:.1f}s silence")

                    await self.send_transcription(self.accumulated_transcript, is_final=True)
                    self.final_sent = True

                    # Reset so a following utterance can produce a new final.
                    self.interim_count = 0
                    self.last_interim_time = None
                    self.accumulated_transcript = ""

                await asyncio.sleep(0.5)

            except Exception as e:
                logger.error(f"Error in auto-final monitoring: {e}")
                await asyncio.sleep(1)

    async def send_transcription(self, text, is_final=False, confidence=0.9):
        """Send a FINAL transcription in jambonz format, converting Arabic
        number words to digits first.  Interim results are intentionally
        dropped (jambonz receives finals only)."""
        try:
            if not is_final:
                logger.debug("Skipping interim transcription (not final).")
                return

            original_text = text
            converted_text = convert_arabic_numbers_in_sentence(text)

            if original_text != converted_text:
                logger.info(f"Arabic numbers converted: '{original_text}' -> '{converted_text}'")

            message = {
                "type": "transcription",
                "is_final": True,
                "alternatives": [
                    {
                        # BUG FIX: this previously sent original_text, so
                        # the number conversion (logged above and below)
                        # never actually reached jambonz.
                        "transcript": converted_text,
                        "confidence": confidence
                    }
                ],
                "language": self.config.get("language", "ar-EG"),
                "channel": 1
            }

            await self.websocket.send(json.dumps(message))
            logger.info(f"Sent FINAL transcription: '{converted_text}'")

            self.final_sent = True

        except Exception as e:
            logger.error(f"Error sending transcription: {e}")

    async def send_error(self, error_message):
        """Send an error message in jambonz format (best effort)."""
        try:
            message = {
                "type": "error",
                "error": error_message
            }
            await self.websocket.send(json.dumps(message))
            logger.error(f"Sent error: {error_message}")
        except Exception as e:
            logger.error(f"Error sending error message: {e}")
|
|
|
async def handle_jambonz_websocket(websocket):
    """Handle one jambonz WebSocket connection.

    Protocol: text frames are JSON control messages ("start"/"stop");
    binary frames are raw LINEAR16 audio.  A JambonzSTTHandler owns the
    session state and is always stopped in the finally block so the final
    transcription is flushed even on abnormal disconnects.
    """
    client_id = f"jambonz_{id(websocket)}"
    logger.info(f"New jambonz connection: {client_id}")

    handler = JambonzSTTHandler(websocket)

    try:
        async for message in websocket:
            try:
                if isinstance(message, str):
                    # Control channel: JSON messages.
                    data = json.loads(message)
                    message_type = data.get("type")

                    if message_type == "start":
                        logger.info(f"Received start message: {data}")
                        await handler.start_processing(data)

                    elif message_type == "stop":
                        logger.info("Received stop message")
                        await handler.stop_processing()

                        await websocket.close(code=1000, reason="Session completed")
                        break

                    else:
                        logger.warning(f"Unknown message type: {message_type}")
                        await handler.send_error(f"Unknown message type: {message_type}")

                else:
                    # Binary frame: raw audio. Requires a prior "start".
                    if handler.audio_buffer is None:
                        await handler.send_error("Received audio before start message")
                        continue

                    await handler.add_audio_data(message)

            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")
                await handler.send_error(f"Invalid JSON: {str(e)}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                await handler.send_error(f"Processing error: {str(e)}")

    except websockets.exceptions.ConnectionClosed:
        logger.info(f"jambonz connection closed: {client_id}")
    except Exception as e:
        logger.error(f"jambonz WebSocket error: {e}")
        try:
            await handler.send_error(str(e))
        except Exception:
            # Best-effort error report only.  (Was a bare `except:`, which
            # also swallowed KeyboardInterrupt/CancelledError.)
            pass
    finally:
        if handler.running:
            await handler.stop_processing()
        logger.info(f"jambonz connection ended: {client_id}")
|
|
|
async def main():
    """Start the jambonz custom-STT WebSocket server on port 3006 and run
    until the server is closed."""
    logger.info("Starting Jambonz Custom STT WebSocket server on port 3006...")

    server = await websockets.serve(
        handle_jambonz_websocket,
        "0.0.0.0",
        3006,
        ping_interval=20,
        ping_timeout=10,
        close_timeout=10
    )

    logger.info("Jambonz Custom STT WebSocket server started on ws://0.0.0.0:3006")
    logger.info("Ready to handle jambonz STT requests")
    logger.info("- Expects LINEAR16 PCM audio at 8kHz")
    logger.info("- Supports interim results with auto-final detection")
    # Fixed: the previous message claimed "3+ interim results + 1.3s
    # silence", contradicting JambonzSTTHandler's actual defaults
    # (min_interim_count=1, silence_timeout=1.5).
    logger.info("- Auto-final: 1+ interim results + 1.5s silence")
    logger.info("- Resamples to 16kHz for Whisper processing")
    logger.info("- Converts Arabic numbers to digits before sending")

    await server.wait_closed()
|
|
|
if __name__ == "__main__":
    # Startup banner describing the server configuration.
    # NOTE(review): the "2+ speech activities + 1.3s silence" line below
    # does not match JambonzSTTHandler's defaults (min_interim_count=1,
    # silence_timeout=1.5) - confirm which values are intended.
    print("=" * 60)
    print("Jambonz Custom STT Server with Whisper + Arabic Numbers")
    print("=" * 60)
    print(f"Model: {MODEL_NAME}")
    print(f"Device: {device}")
    print("WebSocket Port: 3006")
    print("Protocol: jambonz STT API")
    print("Audio Format: LINEAR16 PCM @ 8kHz")
    print("Auto-Final: 2+ speech activities + 1.3s silence")
    print("Arabic Numbers: Converted to digits in FINAL transcriptions only")
    print("Interim Results: DISABLED (final transcription only)")
    if arabic_numbers_available:
        print("✓ pyarabic library available for number conversion")
    else:
        print("✗ pyarabic library not available - install with: pip install pyarabic")
    print("=" * 60)

    # Run the asyncio server; Ctrl-C exits cleanly.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutting down server...")
    except Exception as e:
        print(f"Server error: {e}")
|
|
|