# Jambonz_impl / speech_brain_whisper_denoiser.py
# Uploaded by alaatiger989 via huggingface_hub (commit 04c4cd1, verified)
# import torch
# import asyncio
# import websockets
# import json
# import threading
# import numpy as np
# from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline , WhisperForConditionalGeneration, WhisperProcessor
# import subprocess
# import logging
# import time
# from concurrent.futures import ThreadPoolExecutor
# import struct
# import re
# 3 - 10 - 2025
import torch
import asyncio
import websockets
import json
import threading
import numpy as np
from transformers import pipeline
import subprocess
import logging
import time
from concurrent.futures import ThreadPoolExecutor
import re
import tempfile
import os
import soundfile as sf
from pathlib import Path
# --- Denoiser added ---
# Optional denoiser backend: noisereduce (spectral gating).
# NOTE(review): denoiser_available is reassigned further down when the
# facebook/denoiser model is probed; this flag is effectively superseded.
try:
    import noisereduce as nr
    denoiser_available = True
    print("Denoiser available (using noisereduce)")
except ImportError:
    denoiser_available = False
    print("noisereduce not available - install with: pip install noisereduce")
##############################################################################################
# Arabic number conversion imports: pyarabic turns Arabic number words into ints.
try:
    from pyarabic.number import text2number
    arabic_numbers_available = True
    print("Arabic number conversion available")
except ImportError:
    arabic_numbers_available = False
    print("pyarabic not available - install with: pip install pyarabic")
    print("Arabic numbers will not be converted to digits")
# Set up logging (module-wide logger used throughout this file).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 3 - 10 - 2025
# def denoise_audio(audio_data, sample_rate=16000):
# """Apply noise reduction to audio using noisereduce."""
# if not denoiser_available or len(audio_data) == 0:
# return audio_data
# try:
# reduced = nr.reduce_noise(y=audio_data, sr=sample_rate)
# return reduced.astype(np.float32)
# except Exception as e:
# logger.warning(f"Denoiser failed: {e}")
# return audio_data
#############################################################################################
def convert_arabic_numbers_in_sentence(sentence: str) -> str:
    """
    Replace Arabic number words in a sentence with digits, preserving all
    other words, whitespace, and punctuation.

    Common spelling variants are first normalized to the canonical forms
    that pyarabic's ``text2number`` recognizes, and zero ("صفر") is handled
    explicitly.  If pyarabic is not installed, every word is left unchanged
    (the NameError from ``text2number`` is swallowed per-word below).

    Args:
        sentence: Input text, possibly containing Arabic number words.

    Returns:
        The sentence with recognized number words replaced by digit strings;
        on any unexpected error the original sentence is returned unchanged.
    """
    try:
        # --- Normalization step: map spelling variants to canonical forms ---
        replacements = {
            "اربعة": "أربعة",
            "اربع": "أربع",
            "اثنين": "اثنان",
            "اتنين": "اثنان",  # Egyptian variant
            "ثلاث": "ثلاثة",
            "خمس": "خمسة",
            "ست": "ستة",
            "سبع": "سبعة",
            "ثمان": "ثمانية",
            "تسع": "تسعة",
            "عشر": "عشرة",
        }
        for wrong, correct in replacements.items():
            sentence = re.sub(rf"\b{wrong}\b", correct, sentence)
        # --- Split on whitespace but keep the separators so spacing survives ---
        words = re.split(r'(\s+)', sentence)
        converted_words = []
        for word in words:
            stripped = word.strip()
            if not stripped:  # whitespace separator - keep verbatim
                converted_words.append(word)
                continue
            try:
                num = text2number(stripped)
                # text2number yields 0 for non-number words, so only accept a
                # zero result when the word literally is "صفر" (zero).
                if isinstance(num, int):
                    if num != 0 or stripped == "صفر":
                        converted_words.append(str(num))
                    else:
                        converted_words.append(word)
                else:
                    converted_words.append(word)
            except Exception:
                # Not a number word (or pyarabic unavailable) - keep original.
                converted_words.append(word)
        return ''.join(converted_words)
    except Exception as e:
        logger.warning(f"Error converting Arabic numbers: {e}")
        return sentence
# Try to install flash-attn if not available.
# NOTE(review): use_flash_attn is set but never read in this file -
# presumably a leftover from the transformers-pipeline version above.
try:
    import flash_attn
    use_flash_attn = True
except ImportError:
    print("Flash attention not available, using standard attention")
    use_flash_attn = False
    # Best-effort installs; check=False + bare except means failures are
    # silently ignored.  NOTE(review): shell=True with a fixed string is
    # safe here (no untrusted input), but a list argv would be preferable.
    try:
        subprocess.run(
            "pip install websockets",
            shell=True,
            check=False
        )
        subprocess.run(
            "pip install flash-attn --no-build-isolation",
            shell=True,
            check=False
        )
    except:
        pass
# Select compute device once; reused by the denoiser and transcription paths.
device = "cuda" if torch.cuda.is_available() else "cpu"
# --- Facebook Denoiser added ---
try:
    import torchaudio
    from denoiser import pretrained
    # Load DNS64 pretrained model (auto-downloads if not cached).
    # NOTE(review): only ImportError is caught - a failed download or load
    # would leave denoiser_model undefined and crash at import time.
    denoiser_model = pretrained.dns64().to(device)
    denoiser_model.eval()
    denoiser_available = True
    print("facebook/denoiser loaded successfully")
except ImportError as e:
    denoiser_available = False
    print("facebook/denoiser not available - install with: pip install denoiser torchaudio")
    denoiser_model = None
# NOTE(review): torch_dtype is unused by the SpeechBrain path below.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# NOTE(review): this MODEL_NAME is immediately overwritten two lines down;
# kept only for the startup banner / history.
MODEL_NAME = "alaatiger989/FT_Arabic_Whisper_V1_1"#"openai/whisper-large-v3-turbo"
print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
MODEL_NAME = "speechbrain/asr-whisper-large-v2-commonvoice-ar"
# Replace your pipeline definition
from speechbrain.inference.ASR import WhisperASR
# Load the SpeechBrain model (downloads to savedir on first run).
model = WhisperASR.from_hparams(
    source="speechbrain/asr-whisper-large-v2-commonvoice-ar",
    savedir="pretrained_models/asr-whisper-large-v2-commonvoice-ar",
    run_opts={"device": "cuda"} if torch.cuda.is_available() else {}
)
def denoise_audio(audio_data, sample_rate=16000):
    """Apply denoising using the facebook/denoiser (DNS64) pretrained model.

    Args:
        audio_data: 1-D float32 numpy array of audio samples.
        sample_rate: Nominal sample rate of the input.  Not passed to the
            model - DNS64 presumably expects 16 kHz input; callers resample
            to 16 kHz first (confirm against the denoiser package docs).

    Returns:
        Denoised float32 numpy array, or the input unchanged when the
        denoiser is unavailable, the input is empty, or inference fails.
    """
    if denoiser_model is None or len(audio_data) == 0:
        return audio_data
    try:
        # Shape (1, T) on the same device the model was loaded on.
        audio_tensor = torch.tensor(audio_data, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            denoised_tensor = denoiser_model(audio_tensor)[0]  # model takes no sample_rate arg
        return denoised_tensor.squeeze().cpu().numpy().astype("float32")
    except Exception as e:
        # Log through the module logger (was a raw print), matching the
        # error handling style used everywhere else in this file.
        logger.warning(f"Denoiser failed: {e}")
        return audio_data
# Thread pool for processing audio: blocking transcription work is pushed
# here via run_in_executor so the asyncio event loop stays responsive.
executor = ThreadPoolExecutor(max_workers=4)
class JambonzAudioBuffer:
    """Thread-safe audio accumulator for one jambonz STT session.

    Two arrays are maintained under a single lock:
      * ``buffer``      - pending samples, used only as a trigger to decide
                          when another chunk's worth of audio has arrived.
      * ``total_audio`` - every sample received, used to transcribe the
                          whole utterance so far (interim and final).
    """

    def __init__(self, sample_rate=8000, chunk_duration=1.0):
        self.sample_rate = sample_rate
        self.chunk_duration = chunk_duration
        # Samples per processing trigger (one chunk_duration's worth).
        self.chunk_samples = int(chunk_duration * sample_rate)
        self.buffer = np.array([], dtype=np.float32)
        self.lock = threading.Lock()
        self.total_audio = np.array([], dtype=np.float32)
        # Voice Activity Detection (simple energy-based)
        self.silence_threshold = 0.01
        self.min_speech_samples = int(0.3 * sample_rate)  # 300ms minimum speech

    def add_audio(self, audio_data):
        """Append samples to both the trigger buffer and the full recording."""
        with self.lock:
            self.buffer = np.concatenate([self.buffer, audio_data])
            self.total_audio = np.concatenate([self.total_audio, audio_data])

    def has_chunk_ready(self):
        """True when at least one unconsumed chunk of audio has accumulated."""
        with self.lock:
            return len(self.buffer) >= self.chunk_samples

    def is_speech(self, audio_chunk):
        """Simple VAD: mean absolute energy above a fixed threshold."""
        if len(audio_chunk) < self.min_speech_samples:
            return False
        energy = np.mean(np.abs(audio_chunk))
        return energy > self.silence_threshold

    def get_chunk_for_processing(self):
        """Signal that a chunk-sized batch of new audio has arrived.

        Returns a dummy array once per accumulated chunk, else None.
        BUGFIX: the chunk's samples are now drained from the trigger buffer,
        so each chunk fires downstream transcription exactly once.
        Previously the buffer was never consumed, which made the caller
        re-transcribe on every 0.1s poll and let the buffer grow without
        bound.  ``total_audio`` still retains the complete recording.
        """
        with self.lock:
            if len(self.buffer) < self.chunk_samples:
                return None
            self.buffer = self.buffer[self.chunk_samples:]
            return np.array([1])  # Return a dummy array to signal chunk is ready

    def get_all_audio(self):
        """Get all accumulated audio for final transcription"""
        with self.lock:
            return self.total_audio.copy()

    def clear(self):
        """Drop all buffered and accumulated audio."""
        with self.lock:
            self.buffer = np.array([], dtype=np.float32)
            self.total_audio = np.array([], dtype=np.float32)
def linear16_to_audio(audio_bytes, sample_rate=8000):
    """Decode LINEAR16 PCM bytes (the jambonz wire format) into float32.

    Returns samples normalized to [-1.0, 1.0); an empty float32 array on
    decode failure.
    """
    try:
        samples = np.frombuffer(audio_bytes, dtype=np.int16)
        # Scale 16-bit signed integers into the unit float range.
        return samples.astype(np.float32) / np.float32(32768.0)
    except Exception as e:
        logger.error(f"Error converting LINEAR16 to audio: {e}")
        return np.array([], dtype=np.float32)
def resample_audio(audio_data, source_rate, target_rate):
    """Resample audio by linear interpolation (e.g. 8 kHz -> 16 kHz).

    Args:
        audio_data: 1-D numpy array of samples.
        source_rate: Input sample rate in Hz.
        target_rate: Desired output sample rate in Hz.

    Returns:
        float32 array at the target rate; the input object itself when the
        rates already match; an empty float32 array for empty input.
    """
    if source_rate == target_rate:
        return audio_data
    # BUGFIX: np.interp raises "array of sample points is empty" for empty
    # input, so guard explicitly.
    if len(audio_data) == 0:
        return np.array([], dtype=np.float32)
    ratio = target_rate / source_rate
    # Fractional source positions for each output sample.
    indices = np.arange(0, len(audio_data), 1/ratio)
    indices = indices[indices < len(audio_data)]
    resampled = np.interp(indices, np.arange(len(audio_data)), audio_data)
    return resampled.astype(np.float32)
# NOTE(review): the imports below (os, tempfile, soundfile, logging, Path)
# duplicate those at the top of the file, and the logger is re-fetched;
# harmless but redundant.  uuid/shutil appear unused in this file.
import os
import tempfile
import soundfile as sf
import logging
logger = logging.getLogger(__name__)
from pathlib import Path
import uuid
import shutil
# Project-level temp folder for scratch audio files.
PROJECT_DIR = Path(__file__).parent.resolve()
AUDIO_TMP_DIR = PROJECT_DIR / "temp_audio"
AUDIO_TMP_DIR.mkdir(exist_ok=True)
def transcribe_chunk_direct(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Resample, denoise, and transcribe a buffer with SpeechBrain WhisperASR.

    Args:
        audio_data: 1-D float32 numpy array at source_sample_rate.
        source_sample_rate: Incoming rate (jambonz sends 8 kHz).
        target_sample_rate: Model rate (Whisper expects 16 kHz).

    Returns:
        The stripped transcription string, or "" for empty/too-short input
        or on any transcription error.
    """
    try:
        if len(audio_data) == 0:
            return ""
        # Step 1: Resample (8 kHz telephony -> 16 kHz model input)
        resampled_audio = resample_audio(audio_data, source_sample_rate, target_sample_rate)
        # Step 2: Denoise
        resampled_audio = denoise_audio(resampled_audio)
        # Step 3: Check minimum length (100ms)
        min_samples = int(0.1 * target_sample_rate)
        if len(resampled_audio) < min_samples:
            return ""
        # Step 4: Convert numpy -> torch tensor, shape [1, T]
        waveform = torch.tensor(resampled_audio, dtype=torch.float32).unsqueeze(0)
        # Step 5: Relative wav lengths (1.0 = full length, no padding)
        wav_lens = torch.tensor([1.0])
        # Step 6: Transcribe
        words, tokens = model.transcribe_batch(waveform, wav_lens)
        # Step 7: Join the first batch item's words into a sentence.
        # NOTE(review): if transcribe_batch returns a plain string per item
        # (rather than a list of word tokens), this join would insert a space
        # between every character - confirm against the SpeechBrain API.
        transcription = " ".join(words[0]) if words and len(words) > 0 else ""
        logger.info(f"SpeechBrain transcription: '{transcription}'")
        return transcription.strip()
    except Exception as e:
        logger.error(f"Error during SpeechBrain transcription: {e}")
        return ""
# def transcribe_chunk_direct(audio_data, source_sample_rate=8000, target_sample_rate=16000):
# """Transcribe audio chunk using model's generate method directly"""
# try:
# if len(audio_data) == 0:
# return ""
# # Resample from 8kHz to 16kHz for Whisper
# resampled_audio = resample_audio(audio_data, source_sample_rate, target_sample_rate)
# # Ensure minimum length for Whisper
# min_samples = int(0.1 * target_sample_rate) # 100ms minimum
# if len(resampled_audio) < min_samples:
# return ""
# start_time = time.time()
# # Prepare input features with proper dtype
# input_features = processor(
# resampled_audio,
# sampling_rate=target_sample_rate,
# return_tensors="pt"
# ).input_features
# # Ensure correct dtype and device
# input_features = input_features.to(device=device, dtype=torch_dtype)
# # Create attention mask to avoid warnings
# attention_mask = torch.ones(
# input_features.shape[:-1],
# dtype=torch.long,
# device=device
# )
# # Generate transcription using model directly
# with torch.no_grad():
# predicted_ids = model.generate(
# input_features,
# attention_mask=attention_mask,
# max_new_tokens=128,
# do_sample=False,
# temperature=0.0,
# num_beams=1,
# language="ar",
# task="transcribe",
# pad_token_id=tokenizer.pad_token_id,
# eos_token_id=tokenizer.eos_token_id
# )
# # Decode the transcription
# transcription = tokenizer.batch_decode(
# predicted_ids,
# skip_special_tokens=True
# )[0].strip()
# end_time = time.time()
# logger.info(f"Direct transcription completed in {end_time - start_time:.2f}s: '{transcription}'")
# return transcription
# except Exception as e:
# logger.error(f"Error during direct transcription: {e}")
# return ""
class JambonzSTTHandler:
    """Per-connection STT session state machine.

    Buffers incoming jambonz audio, runs background interim transcription,
    and auto-finalizes an utterance after a period of silence.  Only FINAL
    results are actually sent to jambonz (see send_transcription).
    """

    def __init__(self, websocket):
        self.websocket = websocket
        self.audio_buffer = None          # created on the "start" message
        self.config = {}
        self.running = True
        self.transcription_task = None    # background interim-transcription task
        self.full_transcript = ""
        self.last_partial = ""
        # Auto-final detection variables
        self.interim_count = 0
        self.last_interim_time = None
        self.silence_timeout = 1.5        # seconds of silence to trigger final
        self.min_interim_count = 1        # Minimum interim results before considering final
        self.auto_final_task = None
        self.accumulated_transcript = ""
        self.final_sent = False

    async def start_processing(self, start_message):
        """Initialize with start message from jambonz"""
        self.config = {
            "language": start_message.get("language", "ar-EG"),
            "format": start_message.get("format", "raw"),
            "encoding": start_message.get("encoding", "LINEAR16"),
            "sample_rate": start_message.get("sampleRateHz", 8000),
            "interim_results": start_message.get("interimResults", True),
            "options": start_message.get("options", {})
        }
        logger.info(f"STT session started with config: {self.config}")
        # Initialize audio buffer
        self.audio_buffer = JambonzAudioBuffer(
            sample_rate=self.config["sample_rate"],
            chunk_duration=1.0  # Process every 1 second
        )
        # Reset auto-final detection variables
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False
        # Start background transcription task
        self.transcription_task = asyncio.create_task(self._process_audio_chunks())
        # Start auto-final detection task
        self.auto_final_task = asyncio.create_task(self._monitor_for_auto_final())

    async def stop_processing(self):
        """Stop processing and send final transcription"""
        self.running = False
        # Cancel background tasks and await their cancellation cleanly.
        if self.transcription_task:
            self.transcription_task.cancel()
            try:
                await self.transcription_task
            except asyncio.CancelledError:
                pass
        if self.auto_final_task:
            self.auto_final_task.cancel()
            try:
                await self.auto_final_task
            except asyncio.CancelledError:
                pass
        # Send final transcription if not already sent
        if not self.final_sent and self.accumulated_transcript.strip():
            await self.send_transcription(self.accumulated_transcript, is_final=True)
        # Also process any remaining audio for comprehensive final transcription
        if self.audio_buffer:
            all_audio = self.audio_buffer.get_all_audio()
            if len(all_audio) > 0 and not self.final_sent:
                loop = asyncio.get_event_loop()
                # Run the blocking transcription off the event loop.
                final_transcription = await loop.run_in_executor(
                    executor,
                    transcribe_chunk_direct,
                    all_audio,
                    self.config["sample_rate"]
                )
                if final_transcription.strip():
                    # Send comprehensive final transcription
                    await self.send_transcription(final_transcription, is_final=True)
        logger.info("STT session ended")

    async def add_audio_data(self, audio_bytes):
        """Decode LINEAR16 bytes and append them to the session buffer."""
        if self.audio_buffer:
            audio_data = linear16_to_audio(audio_bytes, self.config["sample_rate"])
            self.audio_buffer.add_audio(audio_data)

    async def _process_audio_chunks(self):
        """Background loop: produce interim transcriptions while audio arrives.

        Each pass re-transcribes ALL accumulated audio (not just the newest
        chunk) so the interim text is the complete utterance so far.
        """
        while self.running and self.config.get("interim_results", False):
            try:
                if self.audio_buffer and self.audio_buffer.has_chunk_ready():
                    chunk_signal = self.audio_buffer.get_chunk_for_processing()
                    if chunk_signal is not None:
                        # Get all accumulated audio so far for complete transcription
                        all_audio = self.audio_buffer.get_all_audio()
                        # Only process if the newest chunk contains speech (VAD).
                        if len(all_audio) > 0 and self.audio_buffer.is_speech(all_audio[-self.audio_buffer.chunk_samples:]):
                            # Run transcription on all accumulated audio
                            loop = asyncio.get_event_loop()
                            transcription = await loop.run_in_executor(
                                executor,
                                transcribe_chunk_direct,
                                all_audio,
                                self.config["sample_rate"]
                            )
                            # Skip empty results and exact repeats of the last partial.
                            if transcription.strip() and transcription != self.last_partial:
                                self.last_partial = transcription
                                self.accumulated_transcript = transcription  # Update accumulated transcript
                                self.interim_count += 1
                                self.last_interim_time = time.time()
                                # Send interim result (send_transcription drops non-finals).
                                await self.send_transcription(transcription, is_final=False)
                                logger.info(f"Interim #{self.interim_count}: '{transcription}'")
                # Small delay to prevent excessive processing
                await asyncio.sleep(0.1)
            except Exception as e:
                logger.error(f"Error in chunk processing: {e}")
                await asyncio.sleep(1)

    async def _monitor_for_auto_final(self):
        """Background loop: emit a final result after silence.

        Fires once min_interim_count interim results have been seen and
        silence_timeout seconds have passed since the last interim.
        """
        while self.running:
            try:
                current_time = time.time()
                # Check if we should send auto-final transcription
                if (self.interim_count >= self.min_interim_count and
                        self.last_interim_time is not None and
                        (current_time - self.last_interim_time) >= self.silence_timeout and
                        not self.final_sent and
                        self.accumulated_transcript.strip()):
                    logger.info(f"Auto-final triggered: {self.interim_count} interim results, "
                                f"{current_time - self.last_interim_time:.1f}s silence")
                    # Send the accumulated transcript as final
                    await self.send_transcription(self.accumulated_transcript, is_final=True)
                    self.final_sent = True
                    # Reset counters for potential next utterance
                    self.interim_count = 0
                    self.last_interim_time = None
                    self.accumulated_transcript = ""
                # Check every 0.5 seconds
                await asyncio.sleep(0.5)
            except Exception as e:
                logger.error(f"Error in auto-final monitoring: {e}")
                await asyncio.sleep(1)

    async def send_transcription(self, text, is_final=False, confidence=0.9):
        """Send a transcription to jambonz - FINAL results only.

        Interim results are deliberately dropped here; the interim machinery
        above still runs to drive the auto-final silence detection.
        """
        try:
            if not is_final:
                # Do nothing for interim results
                logger.debug("Skipping interim transcription (not final).")
                return
            # Convert Arabic numbers only for final transcripts
            original_text = text
            converted_text = convert_arabic_numbers_in_sentence(text)
            # Log the conversion if numbers were found and converted
            if original_text != converted_text:
                logger.info(f"Arabic numbers converted: '{original_text}' -> '{converted_text}'")
            message = {
                "type": "transcription",
                "is_final": True,
                "alternatives": [
                    {
                        # NOTE(review): the UNCONVERTED text is sent even though
                        # the conversion is computed and logged above - the
                        # converted form was intentionally commented out here.
                        "transcript": original_text,#converted_text,
                        "confidence": confidence
                    }
                ],
                "language": self.config.get("language", "ar-EG"),
                "channel": 1
            }
            # Send only final messages
            await self.websocket.send(json.dumps(message))
            logger.info(f"Sent FINAL transcription: '{converted_text}'")
            self.final_sent = True
        except Exception as e:
            logger.error(f"Error sending transcription: {e}")

    async def send_error(self, error_message):
        """Send error message in jambonz format"""
        try:
            message = {
                "type": "error",
                "error": error_message
            }
            await self.websocket.send(json.dumps(message))
            logger.error(f"Sent error: {error_message}")
        except Exception as e:
            logger.error(f"Error sending error message: {e}")
async def handle_jambonz_websocket(websocket):
    """Handle one jambonz WebSocket connection.

    Protocol: JSON text frames carry control messages ("start"/"stop");
    binary frames carry LINEAR16 PCM audio.  A handler instance owns all
    per-session state and is stopped in the finally block if still running.
    """
    client_id = f"jambonz_{id(websocket)}"
    logger.info(f"New jambonz connection: {client_id}")
    handler = JambonzSTTHandler(websocket)
    try:
        async for message in websocket:
            try:
                if isinstance(message, str):
                    # Handle JSON control messages
                    data = json.loads(message)
                    message_type = data.get("type")
                    if message_type == "start":
                        logger.info(f"Received start message: {data}")
                        await handler.start_processing(data)
                    elif message_type == "stop":
                        logger.info("Received stop message")
                        await handler.stop_processing()
                        # Close websocket after final transcription
                        await websocket.close(code=1000, reason="Session completed")
                        break
                    else:
                        logger.warning(f"Unknown message type: {message_type}")
                        await handler.send_error(f"Unknown message type: {message_type}")
                else:
                    # Handle binary audio data (LINEAR16 PCM); audio before
                    # "start" is rejected because no buffer exists yet.
                    if handler.audio_buffer is None:
                        await handler.send_error("Received audio before start message")
                        continue
                    await handler.add_audio_data(message)
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")
                await handler.send_error(f"Invalid JSON: {str(e)}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                await handler.send_error(f"Processing error: {str(e)}")
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"jambonz connection closed: {client_id}")
    except Exception as e:
        logger.error(f"jambonz WebSocket error: {e}")
        # Best-effort error notification; bare except because the socket
        # may already be gone.
        try:
            await handler.send_error(str(e))
        except:
            pass
    finally:
        # Ensure background tasks are cancelled even on abnormal exits.
        if handler.running:
            await handler.stop_processing()
        logger.info(f"jambonz connection ended: {client_id}")
async def main():
    """Start the jambonz STT WebSocket server and run until it closes."""
    logger.info("Starting Jambonz Custom STT WebSocket server on port 3006...")
    # Start WebSocket server (keepalive pings every 20s, 10s timeouts).
    server = await websockets.serve(
        handle_jambonz_websocket,
        "0.0.0.0",
        3006,
        ping_interval=20,
        ping_timeout=10,
        close_timeout=10
    )
    logger.info("Jambonz Custom STT WebSocket server started on ws://0.0.0.0:3006")
    logger.info("Ready to handle jambonz STT requests")
    logger.info("- Expects LINEAR16 PCM audio at 8kHz")
    logger.info("- Supports interim results with auto-final detection")
    logger.info("- Auto-final: 3+ interim results + 1.3s silence")
    logger.info("- Resamples to 16kHz for Whisper processing")
    logger.info("- Converts Arabic numbers to digits before sending")
    # Wait for the server to close
    await server.wait_closed()
if __name__ == "__main__":
    # Startup banner.  NOTE(review): several banner lines disagree with the
    # actual configuration above (auto-final is min_interim_count=1 interim
    # + 1.5s silence); runtime strings are left untouched here.
    print("=" * 60)
    print("Jambonz Custom STT Server with Whisper + Arabic Numbers")
    print("=" * 60)
    print(f"Model: {MODEL_NAME}")
    print(f"Device: {device}")
    print("WebSocket Port: 3006")
    print("Protocol: jambonz STT API")
    print("Audio Format: LINEAR16 PCM @ 8kHz")
    print("Auto-Final: 2+ speech activities + 1.3s silence")
    print("Arabic Numbers: Converted to digits in FINAL transcriptions only")
    print("Interim Results: DISABLED (final transcription only)")
    if arabic_numbers_available:
        print("✓ pyarabic library available for number conversion")
    else:
        print("✗ pyarabic library not available - install with: pip install pyarabic")
    print("=" * 60)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutting down server...")
    except Exception as e:
        print(f"Server error: {e}")