# Jambonz_impl / aqib-nemo-asr.py
# Uploaded by alaatiger989 via huggingface_hub (commit 04c4cd1, verified)
import asyncio
import websockets
import json
import threading
import numpy as np
import logging
import time
import tempfile
import os
import re
from concurrent.futures import ThreadPoolExecutor
import nemo.collections.asr as nemo_asr
import soundfile as sf
# Set up logging
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
# ===== Arabic number mapping (expanded) =====
# Maps spoken Arabic number words (MSA and Egyptian colloquial variants)
# and Eastern Arabic digit strings to Western digit strings.
arabic_numbers = {
    # Basic digits
    "صفر": "0", "زيرو": "0", "٠": "0", "زيو": "0", "زير": "0",
    "واحد": "1", "واحدة": "1", "١": "1",
    "اتنين": "2", "اثنين": "2", "إثنين": "2", "اثنان": "2", "إثنان": "2", "٢": "2",
    "تلاتة": "3", "ثلاثة": "3", "٣": "3",
    "اربعة": "4", "أربعة": "4", "٤": "4",
    "خمسة": "5", "٥": "5",
    "ستة": "6", "٦": "6",
    "سبعة": "7", "٧": "7",
    "تمانية": "8", "ثمانية": "8", "٨": "8",
    "تسعة": "9", "٩": "9",
    # Teens
    "عشرة": "10", "١٠": "10",
    "حداشر": "11", "احد عشر": "11", "احداشر": "11",
    "اتناشر": "12", "اثنا عشر": "12",
    "تلتاشر": "13", "ثلاثة عشر": "13",
    "اربعتاشر": "14", "أربعة عشر": "14",
    "خمستاشر": "15", "خمسة عشر": "15",
    "ستاشر": "16", "ستة عشر": "16",
    "سبعتاشر": "17", "سبعة عشر": "17",
    "طمنتاشر": "18", "ثمانية عشر": "18",
    "تسعتاشر": "19", "تسعة عشر": "19",
    # Tens
    "عشرين": "20", "٢٠": "20",
    "تلاتين": "30", "ثلاثين": "30", "٣٠": "30",
    "اربعين": "40", "أربعين": "40", "٤٠": "40",
    "خمسين": "50", "٥٠": "50",
    "ستين": "60", "٦٠": "60",
    "سبعين": "70", "٧٠": "70",
    "تمانين": "80", "ثمانين": "80", "٨٠": "80", "تمانون": "80", "ثمانون": "80",
    "تسعين": "90", "٩٠": "90",
    # Hundreds
    "مية": "100", "مائة": "100", "مئة": "100", "١٠٠": "100",
    "ميتين": "200", "مائتين": "200",
    "تلاتمية": "300", "ثلاثمائة": "300",
    "اربعمية": "400", "أربعمائة": "400",
    "خمسمية": "500", "خمسمائة": "500",
    "ستمية": "600", "ستمائة": "600",
    "سبعمية": "700", "سبعمائة": "700",
    "تمانمية": "800", "ثمانمائة": "800",
    "تسعمية": "900", "تسعمائة": "900",
    # Thousands
    "ألف": "1000", "الف": "1000", "١٠٠٠": "1000",
    "ألفين": "2000", "الفين": "2000",
    "تلات تلاف": "3000", "ثلاثة آلاف": "3000",
    "اربعة آلاف": "4000", "أربعة آلاف": "4000",
    "خمسة آلاف": "5000",
    "ستة آلاف": "6000",
    "سبعة آلاف": "7000",
    "تمانية آلاف": "8000", "ثمانية آلاف": "8000",
    "تسعة آلاف": "9000",
    # Large numbers
    "عشرة آلاف": "10000",
    "مية ألف": "100000", "مائة ألف": "100000",
    "مليون": "1000000", "١٠٠٠٠٠٠": "1000000",
    "ملايين": "1000000",
    "مليار": "1000000000", "١٠٠٠٠٠٠٠٠٠": "1000000000"
}


def replace_arabic_numbers(text: str) -> str:
    """Replace Arabic number words/numerals in *text* with Western digits.

    Substitutions are applied longest-key-first so that multi-word entries
    such as "عشرة آلاف" (10000) or "خمسة عشر" (15) are matched before their
    shorter prefixes ("عشرة" -> 10, "خمسة" -> 5).  The previous
    insertion-order iteration produced partial replacements like "10 آلاف".
    Keys are escaped so they are treated literally by the regex engine.
    """
    for word in sorted(arabic_numbers, key=len, reverse=True):
        text = re.sub(rf"\b{re.escape(word)}\b", arabic_numbers[word], text)
    return text
# Global NeMo model
# Populated by initialize_nemo_model() at import time; None until then.
asr_model = None
def initialize_nemo_model():
    """Initialize NeMo FastConformer model.

    Loads the Arabic FastConformer checkpoint from the NEMO_MODEL_PATH
    environment variable (placeholder default path) into the module-level
    ``asr_model`` global and switches it to eval mode for inference.

    Raises:
        FileNotFoundError: if the .nemo checkpoint is not on disk.
        Exception: re-raised if NeMo fails to restore the model.
    """
    global asr_model
    logger.info("Loading NeMo FastConformer Arabic ASR model...")
    # Model path - adjust this to your model location
    model_path = os.getenv(
        "NEMO_MODEL_PATH",
        "/path/to/stt_ar_fastconformer_hybrid_large_pcd_v1.0.nemo"  # Update this path
    )
    if not os.path.exists(model_path):
        logger.error(f"Model not found at: {model_path}")
        logger.info("Please download the model from: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_ar_fastconformer_hybrid_large_pcd")
        raise FileNotFoundError(f"NeMo model not found: {model_path}")
    try:
        # NOTE(review): the checkpoint name says "hybrid"; confirm that
        # EncDecCTCModel is the correct restore class for this .nemo file.
        asr_model = nemo_asr.models.EncDecCTCModel.restore_from(model_path)
        logger.info("NeMo FastConformer model loaded successfully")
        # Set model to eval mode for inference
        asr_model.eval()
    except Exception as e:
        logger.error(f"Failed to load NeMo model: {e}")
        raise
# Initialize model on startup
# NOTE: runs at import time, so importing this module fails fast if the
# checkpoint is missing.
initialize_nemo_model()
# Thread pool for processing
# Runs blocking NeMo transcription calls off the asyncio event loop.
executor = ThreadPoolExecutor(max_workers=4)
class JambonzAudioBuffer:
    """Thread-safe audio accumulator for a jambonz STT session.

    Maintains two float32 arrays under one lock:

    * ``buffer`` — gating buffer; once it holds ``chunk_samples`` samples a
      processing pass is due.
    * ``total_audio`` — all audio of the current segment, re-transcribed as
      a whole for best context.
    """

    def __init__(self, sample_rate=8000, chunk_duration=1.0):
        self.sample_rate = sample_rate
        self.chunk_duration = chunk_duration
        # Samples constituting one processing chunk.
        self.chunk_samples = int(chunk_duration * sample_rate)
        self.buffer = np.array([], dtype=np.float32)
        self.lock = threading.Lock()
        self.total_audio = np.array([], dtype=np.float32)
        # Voice Activity Detection: mean-|amplitude| energy threshold and the
        # minimum length (0.5 s) a chunk must have to count as speech at all.
        self.silence_threshold = 0.05
        self.min_speech_samples = int(0.5 * sample_rate)

    def add_audio(self, audio_data):
        """Append new float32 samples to both the gating and segment buffers."""
        with self.lock:
            self.buffer = np.concatenate([self.buffer, audio_data])
            self.total_audio = np.concatenate([self.total_audio, audio_data])

    def has_chunk_ready(self):
        """Return True when a full unprocessed chunk is buffered."""
        with self.lock:
            return len(self.buffer) >= self.chunk_samples

    def is_speech(self, audio_chunk):
        """Simple VAD based on mean absolute energy."""
        if len(audio_chunk) < self.min_speech_samples:
            return False
        energy = np.mean(np.abs(audio_chunk))
        return energy > self.silence_threshold

    def get_chunk_for_processing(self):
        """Consume one chunk from the gating buffer.

        Returns a non-None signal when a full chunk was available, else None.
        Fix: the consumed samples are now drained from ``buffer`` so each
        chunk of new audio triggers exactly one transcription pass.
        Previously the buffer was never drained, so ``has_chunk_ready()``
        stayed True forever and the handler re-ran ASR on every 100 ms poll.
        """
        with self.lock:
            if len(self.buffer) < self.chunk_samples:
                return None
            # Drain one chunk's worth of samples; total_audio is untouched.
            self.buffer = self.buffer[self.chunk_samples:]
            return np.array([1])  # Signal that chunk is ready

    def get_all_audio(self):
        """Return a copy of all audio accumulated for the current segment."""
        with self.lock:
            return self.total_audio.copy()

    def clear(self):
        """Drop all buffered audio."""
        with self.lock:
            self.buffer = np.array([], dtype=np.float32)
            self.total_audio = np.array([], dtype=np.float32)

    def reset_for_new_segment(self):
        """Reset buffers for a new transcription segment."""
        with self.lock:
            self.buffer = np.array([], dtype=np.float32)
            self.total_audio = np.array([], dtype=np.float32)
def linear16_to_audio(audio_bytes, sample_rate=8000):
    """Decode LINEAR16 PCM bytes into a float32 numpy array in [-1.0, 1.0)."""
    try:
        samples = np.frombuffer(audio_bytes, dtype=np.int16)
        # Scale 16-bit signed integers down to the float range.
        return samples.astype(np.float32) / 32768.0
    except Exception as e:
        logger.error(f"Error converting LINEAR16 to audio: {e}")
        return np.array([], dtype=np.float32)
def resample_audio(audio_data, source_rate, target_rate):
    """Resample *audio_data* from source_rate to target_rate.

    Identity when the rates already match; cheap sample-doubling for the
    common 8 kHz -> 16 kHz path; linear interpolation for everything else.
    """
    if source_rate == target_rate:
        return audio_data
    if (source_rate, target_rate) == (8000, 16000):
        # Fast 2x upsampling: duplicate every sample.
        return np.repeat(audio_data, 2).astype(np.float32)
    # General case: evaluate the signal at the new sample positions via
    # linear interpolation.  Step between positions is source/target.
    step = source_rate / target_rate
    positions = np.arange(0, len(audio_data), step)
    positions = positions[positions < len(audio_data)]
    interpolated = np.interp(positions, np.arange(len(audio_data)), audio_data)
    return interpolated.astype(np.float32)
def transcribe_with_nemo(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Transcribe audio using NeMo FastConformer.

    Args:
        audio_data: mono float32 samples at ``source_sample_rate``.
        source_sample_rate: rate of the incoming audio (jambonz default 8 kHz).
        target_sample_rate: rate the model expects (default 16 kHz).

    Returns:
        The cleaned transcription string (Arabic number words converted to
        digits), or "" for empty/too-short audio, empty results, or errors.
    """
    try:
        if len(audio_data) == 0:
            return ""
        # Resample to the model's expected rate.
        # Fix: previously hard-coded 8000 -> 16000 here, silently ignoring
        # both rate parameters; now honours the caller-supplied rates.
        resampled_audio = resample_audio(audio_data, source_sample_rate, target_sample_rate)
        # Skip very short audio (< 0.3 s at the target rate).
        min_samples = int(0.3 * target_sample_rate)
        if len(resampled_audio) < min_samples:
            return ""
        start_time = time.time()
        # Save audio to temporary file (NeMo expects file path)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            # Write audio as WAV file
            sf.write(tmp_file.name, resampled_audio, target_sample_rate)
            tmp_path = tmp_file.name
        try:
            # Transcribe with NeMo
            result = asr_model.transcribe([tmp_path])
            # Debug logging to understand result format
            logger.info(f"NeMo result type: {type(result)}")
            if result and len(result) > 0:
                logger.info(f"First result type: {type(result[0])}")
                logger.info(f"First result content: {result[0]}")
            if result and len(result) > 0:
                # Handle different NeMo result formats across versions.
                if hasattr(result[0], 'text'):
                    # If result has .text attribute (newer NeMo versions)
                    raw_text = result[0].text
                    logger.info(f"Using .text attribute: {raw_text}")
                elif isinstance(result[0], str):
                    # If result is directly a string
                    raw_text = result[0]
                    logger.info(f"Using direct string: {raw_text}")
                else:
                    # If result is some other format, convert to string
                    raw_text = str(result[0])
                    logger.info(f"Using str() conversion: {raw_text}")
                # Ensure raw_text is a string before processing
                if not isinstance(raw_text, str):
                    raw_text = str(raw_text)
                # Only process if we have actual text content
                if raw_text and raw_text.strip():
                    # Convert Arabic numbers to digits
                    logger.info(f"before sending to FXN--- {raw_text}")
                    cleaned_text = replace_arabic_numbers(raw_text)
                    logger.info(f"after FXN--- {cleaned_text}")
                    end_time = time.time()
                    if cleaned_text.strip():
                        logger.info(f"NeMo transcription: '{cleaned_text}' (processed in {end_time - start_time:.2f}s)")
                        return cleaned_text.strip()
                # Fix: empty/whitespace-only text previously fell through and
                # returned None implicitly; callers expect a string.
                logger.info("No transcription text found")
                return ""
            else:
                logger.info("No results from NeMo transcription")
                return ""
        finally:
            # Clean up temporary file
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    except Exception as e:
        logger.error(f"Error during NeMo transcription: {e}")
        return ""
class JambonzSTTHandler:
    """Per-connection jambonz STT session.

    Buffers incoming LINEAR16 audio, periodically re-transcribes the whole
    accumulated segment with NeMo (run in the module thread pool), tracks
    interim results internally, and sends a single final transcription per
    segment back to jambonz — either auto-finalized after a silence timeout
    or flushed when the session stops.
    """

    def __init__(self, websocket):
        self.websocket = websocket
        self.audio_buffer = None  # JambonzAudioBuffer, created on "start"
        self.config = {}
        self.running = False
        self.transcription_task = None
        # Auto-final detection variables
        self.interim_count = 0         # distinct interim updates this segment
        self.last_interim_time = None  # wall-clock time of latest interim
        self.silence_timeout = 2.0     # seconds without a new interim => final
        self.min_interim_count = 2     # interims required before auto-final
        self.auto_final_task = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.segment_number = 0
        self.last_partial = ""
        # Processing tracking
        self.processing_count = 0

    async def start_processing(self, start_message):
        """Initialize with start message from jambonz.

        Builds the session config, creates the audio buffer, resets all
        segment state, and launches the two background tasks (chunk
        processing and auto-final monitoring).
        """
        self.config = {
            "language": start_message.get("language", "ar-EG"),
            "format": start_message.get("format", "raw"),
            "encoding": start_message.get("encoding", "LINEAR16"),
            "sample_rate": start_message.get("sampleRateHz", 8000),
            "interim_results": True,  # Always enable for internal processing
            "options": start_message.get("options", {})
        }
        logger.info(f"NeMo STT session started with config: {self.config}")
        # Initialize audio buffer
        self.audio_buffer = JambonzAudioBuffer(
            sample_rate=self.config["sample_rate"],
            chunk_duration=1.0  # 1 second chunks for NeMo
        )
        # Reset session variables
        self.running = True
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.segment_number = 0
        self.processing_count = 0
        self.last_partial = ""
        # Start background transcription task
        self.transcription_task = asyncio.create_task(self._process_audio_chunks())
        # Start auto-final detection task
        self.auto_final_task = asyncio.create_task(self._monitor_for_auto_final())

    async def stop_processing(self):
        """Stop current processing session.

        Cancels background tasks, sends any pending final transcription,
        optionally runs one last full-audio transcription, and clears the
        buffer.
        """
        logger.info("Stopping NeMo STT session...")
        self.running = False
        # Cancel background tasks and wait for them to acknowledge.
        for task in [self.transcription_task, self.auto_final_task]:
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        # Send final transcription if not already sent
        if not self.final_sent and self.accumulated_transcript.strip():
            await self.send_transcription(self.accumulated_transcript, is_final=True)
        # Process any remaining audio for comprehensive final transcription
        # (only reached if no accumulated transcript was flushed above,
        # since send_transcription sets final_sent).
        if self.audio_buffer:
            all_audio = self.audio_buffer.get_all_audio()
            if len(all_audio) > 0 and not self.final_sent:
                loop = asyncio.get_event_loop()
                final_transcription = await loop.run_in_executor(
                    executor,
                    transcribe_with_nemo,
                    all_audio,
                    self.config["sample_rate"]
                )
                if final_transcription.strip():
                    await self.send_transcription(final_transcription, is_final=True)
        # Clear audio buffer
        if self.audio_buffer:
            self.audio_buffer.clear()
        logger.info("NeMo STT session stopped")

    async def start_new_segment(self):
        """Start a new transcription segment: reset counters and buffers."""
        self.segment_number += 1
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.last_partial = ""
        self.processing_count = 0
        if self.audio_buffer:
            self.audio_buffer.reset_for_new_segment()
        logger.info(f"Started new transcription segment #{self.segment_number}")

    async def add_audio_data(self, audio_bytes):
        """Decode LINEAR16 bytes and add them to the buffer (if running)."""
        if self.audio_buffer and self.running:
            audio_data = linear16_to_audio(audio_bytes, self.config["sample_rate"])
            self.audio_buffer.add_audio(audio_data)

    async def _process_audio_chunks(self):
        """Background loop: transcribe accumulated audio for interim results.

        Polls every 100 ms; when a chunk is ready and the newest chunk looks
        like speech, re-transcribes the ENTIRE segment audio (full context)
        and updates the interim tracking state used by the auto-final monitor.
        """
        while self.running:
            try:
                if self.audio_buffer and self.audio_buffer.has_chunk_ready():
                    chunk_signal = self.audio_buffer.get_chunk_for_processing()
                    if chunk_signal is not None:
                        all_audio = self.audio_buffer.get_all_audio()
                        # VAD only on the most recent chunk's worth of audio.
                        if len(all_audio) > 0 and self.audio_buffer.is_speech(all_audio[-self.audio_buffer.chunk_samples:]):
                            loop = asyncio.get_event_loop()
                            transcription = await loop.run_in_executor(
                                executor,
                                transcribe_with_nemo,
                                all_audio,
                                self.config["sample_rate"]
                            )
                            if transcription.strip():
                                self.processing_count += 1
                                self.accumulated_transcript = transcription
                                # Count a new interim only when the text changed
                                # (or it is the first one); always refresh the
                                # timestamp so silence detection stays accurate.
                                if transcription != self.last_partial or self.interim_count == 0:
                                    self.last_partial = transcription
                                    self.interim_count += 1
                                    self.last_interim_time = time.time()
                                    logger.info(f"Updated interim_count to {self.interim_count} for transcript: '{transcription}'")
                                else:
                                    self.last_interim_time = time.time()
                await asyncio.sleep(0.1)  # Check every 100ms
            except Exception as e:
                logger.error(f"Error in chunk processing: {e}")
                await asyncio.sleep(0.1)

    async def _monitor_for_auto_final(self):
        """Background loop: emit a final result after a silence timeout.

        When enough interims were seen and no new interim arrived within
        ``silence_timeout`` seconds, sends the accumulated transcript as
        final and rolls over to a new segment.
        """
        while self.running:
            try:
                current_time = time.time()
                if (self.interim_count >= self.min_interim_count and
                    self.last_interim_time is not None and
                    (current_time - self.last_interim_time) >= self.silence_timeout and
                    not self.final_sent and
                    self.accumulated_transcript.strip()):
                    logger.info(f"Auto-final triggered for segment #{self.segment_number}")
                    await self.send_transcription(self.accumulated_transcript, is_final=True)
                    await self.start_new_segment()
                await asyncio.sleep(0.5)  # Check every 500ms
            except Exception as e:
                logger.error(f"Error in auto-final monitoring: {e}")
                await asyncio.sleep(0.5)

    async def send_transcription(self, text, is_final=True, confidence=0.9):
        """Send transcription in jambonz format.

        NOTE: the ``is_final`` parameter is deliberately ignored — every
        message is marked final, per the inline comment below.
        """
        try:
            message = {
                "type": "transcription",
                "is_final": True,  # Always send as final
                "alternatives": [
                    {
                        "transcript": text,
                        "confidence": confidence
                    }
                ],
                "language": self.config.get("language", "ar-EG"),
                "channel": 1
            }
            await self.websocket.send(json.dumps(message))
            self.final_sent = True
            logger.info(f"Sent FINAL transcription to Jambonz: '{text}'")
        except Exception as e:
            logger.error(f"Error sending transcription: {e}")

    async def send_error(self, error_message):
        """Send error message in jambonz format."""
        try:
            message = {
                "type": "error",
                "error": error_message
            }
            await self.websocket.send(json.dumps(message))
            logger.error(f"Sent error: {error_message}")
        except Exception as e:
            logger.error(f"Error sending error message: {e}")
async def handle_jambonz_websocket(websocket):
    """Handle jambonz WebSocket connections.

    Text frames carry JSON control messages ("start"/"stop"); binary frames
    carry raw LINEAR16 audio.  One JambonzSTTHandler is created per
    connection and always stopped on disconnect.
    """
    client_id = f"jambonz_{id(websocket)}"
    logger.info(f"New NeMo jambonz connection: {client_id}")
    handler = JambonzSTTHandler(websocket)
    try:
        async for message in websocket:
            try:
                if isinstance(message, str):
                    # JSON control message.
                    data = json.loads(message)
                    message_type = data.get("type")
                    if message_type == "start":
                        logger.info(f"Received start message: {data}")
                        await handler.start_processing(data)
                    elif message_type == "stop":
                        logger.info("Received stop message - closing WebSocket")
                        await handler.stop_processing()
                        await websocket.close(code=1000, reason="Session stopped by client")
                        break
                    else:
                        logger.warning(f"Unknown message type: {message_type}")
                        await handler.send_error(f"Unknown message type: {message_type}")
                else:
                    # Handle binary audio data
                    if not handler.running or handler.audio_buffer is None:
                        logger.warning("Received audio data outside of active session")
                        await handler.send_error("Received audio before start message or after stop")
                        continue
                    await handler.add_audio_data(message)
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")
                await handler.send_error(f"Invalid JSON: {str(e)}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                await handler.send_error(f"Processing error: {str(e)}")
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"NeMo jambonz connection closed: {client_id}")
    except Exception as e:
        logger.error(f"NeMo jambonz WebSocket error: {e}")
        try:
            await handler.send_error(str(e))
        except:
            pass
    finally:
        # Best-effort cleanup: make sure background tasks are cancelled even
        # if the peer disconnected without sending "stop".
        if handler.running:
            await handler.stop_processing()
        logger.info(f"NeMo jambonz connection ended: {client_id}")
async def main():
    """Start the NeMo jambonz STT WebSocket server.

    Binds on all interfaces, port 3007, with WebSocket keepalive pings, and
    blocks until the server is closed.
    """
    logger.info("Starting NeMo Jambonz STT WebSocket server on port 3007...")
    # Start WebSocket server
    server = await websockets.serve(
        handle_jambonz_websocket,
        "0.0.0.0",
        3007,
        ping_interval=20,
        ping_timeout=10,
        close_timeout=10
    )
    logger.info("NeMo Jambonz STT WebSocket server started on ws://0.0.0.0:3007")
    logger.info("Ready to handle jambonz STT requests with NeMo FastConformer")
    logger.info("FEATURES:")
    logger.info("- Arabic ASR using NeMo FastConformer model")
    logger.info("- Arabic number word to digit conversion")
    logger.info("- Continuous transcription with segmentation")
    logger.info("- Voice Activity Detection")
    # Wait for the server to close
    await server.wait_closed()
if __name__ == "__main__":
    # Startup banner describing the server configuration.
    print("=" * 80)
    print("NeMo FastConformer Jambonz STT Server")
    print("=" * 80)
    print("Model: NeMo FastConformer Arabic ASR")
    print("WebSocket Port: 3007")
    print("Protocol: jambonz STT API")
    print("Audio Format: LINEAR16 PCM @ 8kHz → 16kHz")
    print("Language: Arabic with number conversion")
    print("=" * 80)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Clean Ctrl-C shutdown.
        print("\nShutting down NeMo server...")
    except Exception as e:
        print(f"Server error: {e}")