| import gradio as gr |
| import numpy as np |
| import queue |
| import torch |
| import time |
| import threading |
| import os |
| import urllib.request |
| import torchaudio |
| from scipy.spatial.distance import cosine |
| from RealtimeSTT import AudioToTextRecorder |
| from fastapi import FastAPI, APIRouter |
| from fastrtc import Stream, AsyncStreamHandler, ReplyOnPause, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials |
| import json |
| import io |
| import wave |
| import asyncio |
| import uvicorn |
| import socket |
| from queue import Queue |
| import time |
| |
# --- Transcription (RealtimeSTT recorder settings) --------------------------
# Post-speech silence (seconds): index 0 when the live text already looks like a
# finished sentence, index 1 otherwise (see live_text_detected).
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# --- Speaker change detection ----------------------------------------------
DEFAULT_CHANGE_THRESHOLD = 0.7   # cosine similarity below which a change is considered
EMBEDDING_HISTORY_SIZE = 5       # recent embeddings remembered by the detector
MIN_SEGMENT_DURATION = 1.0       # seconds that must pass between two speaker changes
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10       # hard cap; matches the size of the colour palette below

# --- Audio stream ------------------------------------------------------------
FAST_SENTENCE_END = True         # stop the recorder as soon as a sentence end is detected
SAMPLE_RATE = 16000
BUFFER_SIZE = 512
CHANNELS = 1

# --- Speaker palette -----------------------------------------------------------
# One (hex colour, human-readable name) pair per speaker slot; the two public
# lists are derived from this table so they always stay aligned.
_SPEAKER_PALETTE = [
    ("#FFFF00", "Yellow"),
    ("#FF0000", "Red"),
    ("#00FF00", "Green"),
    ("#00FFFF", "Cyan"),
    ("#FF00FF", "Magenta"),
    ("#0000FF", "Blue"),
    ("#FF8000", "Orange"),
    ("#00FF80", "Spring Green"),
    ("#8000FF", "Purple"),
    ("#FFFFFF", "White"),
]
SPEAKER_COLORS = [hex_code for hex_code, _ in _SPEAKER_PALETTE]
SPEAKER_COLOR_NAMES = [label for _, label in _SPEAKER_PALETTE]
|
|
|
|
class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings.

    Call :meth:`load_model` once before :meth:`embed_utterance`. Model files are
    cached under ``~/.cache/speechbrain``.
    """

    # Checkpoint pre-fetched into ``cache_dir`` by ``_download_model``.
    MODEL_URL = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
    # Hugging Face repo passed to ``EncoderClassifier.from_hparams``.
    MODEL_SOURCE = "speechbrain/spkrec-ecapa-voxceleb"

    def __init__(self, device="cpu"):
        """Prepare the (not yet loaded) encoder.

        Args:
            device: torch device string handed to SpeechBrain, e.g. "cpu" or "cuda".
        """
        self.device = device
        self.model = None
        self.embedding_dim = 192  # size of the zero vector returned on failure
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_model(self, model_url=None):
        """Download the ECAPA-TDNN checkpoint into ``cache_dir`` if it is not there yet.

        The file is first written to ``<name>.part`` and atomically renamed on
        success. A previously interrupted download therefore never leaves a
        truncated checkpoint that later calls would mistake for a complete one.

        Args:
            model_url: Optional override of :attr:`MODEL_URL` (any URL scheme
                ``urllib`` understands).

        Returns:
            Path of the cached checkpoint.

        Raises:
            Whatever ``urllib.request.urlretrieve`` raises when the download fails.
        """
        model_url = model_url or self.MODEL_URL
        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")

        if not os.path.exists(model_path):
            print(f"Downloading ECAPA-TDNN model to {model_path}...")
            tmp_path = model_path + ".part"
            try:
                urllib.request.urlretrieve(model_url, tmp_path)
                os.replace(tmp_path, model_path)
            finally:
                # Only left behind when the download failed part-way.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

        return model_path

    def load_model(self):
        """Load the ECAPA-TDNN model.

        Returns:
            True on success; False (after printing the error) on any failure.
        """
        try:
            try:
                # SpeechBrain >= 1.0 location of the pretrained interfaces.
                from speechbrain.inference.speaker import EncoderClassifier
            except ImportError:
                # Older SpeechBrain releases.
                from speechbrain.pretrained import EncoderClassifier

            # Pre-fetch the checkpoint into the same directory used as savedir.
            # NOTE(review): presumably lets from_hparams reuse the local copy — confirm.
            self._download_model()

            self.model = EncoderClassifier.from_hparams(
                source=self.MODEL_SOURCE,
                savedir=self.cache_dir,
                run_opts={"device": self.device}
            )

            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract a speaker embedding from *audio*.

        Args:
            audio: 1-D float waveform as a numpy array or torch tensor. A 2-D
                tensor/array is treated as an already-batched ``(batch, time)`` input.
            sr: Sample rate of *audio*; resampled to 16 kHz when different.

        Returns:
            The embedding as a numpy array. A zero vector of length
            ``embedding_dim`` is returned for empty input or when extraction fails.

        Raises:
            ValueError: if :meth:`load_model` has not succeeded yet.
        """
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")

        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32)
            else:
                waveform = audio.float()

            if waveform.numel() == 0:
                # Nothing to embed; avoid a noisy error from the model.
                return np.zeros(self.embedding_dim)

            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)  # add the batch dimension

            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)

            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)

            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
|
|
|
|
class AudioProcessor:
    """Turns raw int16 PCM into a speaker embedding using a wrapped encoder."""

    def __init__(self, encoder):
        # Object providing embed_utterance(float_audio) and an embedding_dim attribute.
        self.encoder = encoder

    def extract_embedding(self, audio_int16):
        """Scale *audio_int16* to float32 and embed it.

        Samples are divided by 32768; if the result still exceeds 1.0 in
        magnitude, it is peak-normalised. Any failure (including empty input)
        is printed and a zero vector of length ``encoder.embedding_dim`` is returned.
        """
        try:
            samples = audio_int16.astype(np.float32) / 32768.0
            peak = np.abs(samples).max()
            if peak > 1.0:
                samples = samples / peak
            return self.encoder.embed_utterance(samples)
        except Exception as err:
            print(f"Embedding extraction error: {err}")
            return np.zeros(self.encoder.embedding_dim)
|
|
|
|
class SpeakerChangeDetector:
    """Online speaker-change detector over a configurable pool of speaker slots.

    Each slot keeps its most recent embeddings (at most
    ``MAX_EMBEDDINGS_PER_SPEAKER``) and their mean. A new embedding counts as a
    potential change only if two conditions hold:

    * its cosine similarity to the current speaker's mean is below
      ``change_threshold``;
    * at least ``MIN_SEGMENT_DURATION`` seconds have passed since the last change.

    In that case it goes to the most similar other known speaker, provided that
    speaker scores higher than the current one. Otherwise it goes to an unused
    slot if one is free.
    """

    # Number of embeddings per speaker used to compute that speaker's mean.
    MAX_EMBEDDINGS_PER_SPEAKER = 30

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        """Create a detector.

        Args:
            embedding_dim: Length of the embeddings that will be added.
            change_threshold: Similarity below which a speaker change is considered.
            max_speakers: Number of speaker slots. It is coerced to int (UI sliders
                may deliver floats such as 4.0) and clamped to
                ``[1, ABSOLUTE_MAX_SPEAKERS]``.
        """
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = self._clamp_max_speakers(max_speakers)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = {0}

    @staticmethod
    def _clamp_max_speakers(max_speakers):
        """Coerce *max_speakers* to int and clamp it to ``[1, ABSOLUTE_MAX_SPEAKERS]``."""
        # list * n and range(n) require ints; at least one slot must exist.
        return max(1, min(int(max_speakers), ABSOLUTE_MAX_SPEAKERS))

    @staticmethod
    def _is_degenerate(embedding):
        """True for empty, all-zero or non-finite embeddings (no usable direction)."""
        values = np.asarray(embedding, dtype=float)
        return values.size == 0 or not np.all(np.isfinite(values)) or not np.any(values)

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers.

        Shrinking drops the slots, and the history, of speakers beyond the new
        limit. If the current speaker is dropped, speaker 0 becomes current.
        """
        new_max = self._clamp_max_speakers(max_speakers)

        if new_max < self.max_speakers:
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)

            if self.current_speaker >= new_max:
                self.current_speaker = 0

        if new_max > self.max_speakers:
            extra = new_max - self.max_speakers
            self.mean_embeddings.extend([None] * extra)
            self.speaker_embeddings.extend([] for _ in range(extra))
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]

        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the change threshold, clamped to ``[0.1, 0.99]``."""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and decide which speaker it belongs to.

        Args:
            embedding: Speaker embedding vector.
            timestamp: Time of the embedding in seconds; ``time.time()`` when None.
                A timestamp of 0 is honoured.

        Returns:
            ``(speaker_id, similarity)``. The similarity is measured against the
            current speaker before any switch (1.0 for the very first embedding).
            Degenerate embeddings (empty, all-zero or non-finite, e.g. the zero
            vectors returned when extraction fails) are ignored: detector state is
            left untouched and ``(current_speaker, nan)`` is returned.
        """
        current_time = time.time() if timestamp is None else timestamp

        if self._is_degenerate(embedding):
            # A cosine against a zero vector is undefined (NaN). Folding such a
            # vector into a mean would also leave that speaker unmatchable forever.
            return self.current_speaker, float("nan")

        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])

        self.last_similarity = similarity

        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False

        if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
            best_speaker = self.current_speaker
            best_similarity = similarity

            # Prefer an already-known speaker that matches better than the current one.
            for speaker_id in range(self.max_speakers):
                if speaker_id == self.current_speaker:
                    continue
                speaker_mean = self.mean_embeddings[speaker_id]
                if speaker_mean is not None:
                    speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                    if speaker_similarity > best_similarity:
                        best_similarity = speaker_similarity
                        best_speaker = speaker_id

            if best_speaker != self.current_speaker:
                is_speaker_change = True
                self.current_speaker = best_speaker
            elif len(self.active_speakers) < self.max_speakers:
                # Nobody known matches better: open the first unused slot.
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        is_speaker_change = True
                        self.current_speaker = new_id
                        self.active_speakers.add(new_id)
                        break

        if is_speaker_change:
            self.last_change_time = current_time

        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)

        history = self.speaker_embeddings[self.current_speaker]
        history.append(embedding)
        self.active_speakers.add(self.current_speaker)

        if len(history) > self.MAX_EMBEDDINGS_PER_SPEAKER:
            history = history[-self.MAX_EMBEDDINGS_PER_SPEAKER:]
            self.speaker_embeddings[self.current_speaker] = history

        if history:
            self.mean_embeddings[self.current_speaker] = np.mean(history, axis=0)

        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return the palette colour for *speaker_id* ("#FFFFFF" when out of range)."""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"

    def get_status_info(self):
        """Return a dict summarising the detector's current state."""
        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]

        return {
            "current_speaker": self.current_speaker,
            "speaker_counts": speaker_counts,
            "active_speakers": len(self.active_speakers),
            "max_speakers": self.max_speakers,
            "last_similarity": self.last_similarity,
            "threshold": self.change_threshold
        }
|
|
|
|
class RealtimeSpeakerDiarization:
    """Real-time transcription with per-sentence speaker attribution.

    Audio is pushed in with :meth:`feed_audio_data`, or with
    :meth:`process_audio_chunk`, which also downmixes and resamples. A RealtimeSTT
    ``AudioToTextRecorder`` transcribes it on a background thread. Each final
    sentence is queued together with its audio, and a second background thread
    embeds that audio and assigns a speaker via ``SpeakerChangeDetector``.
    """

    def __init__(self):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.recorder = None
        self.sentence_queue = queue.Queue()  # (text, audio) pairs awaiting speaker assignment
        self.full_sentences = []  # (text, embedding) per finished sentence
        self.sentence_speakers = []  # speaker id per entry of full_sentences
        self.pending_sentences = []  # transcribed text not yet assigned a speaker
        self.displayed_text = ""
        self.last_realtime_text = ""
        self.is_running = False
        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
        self.max_speakers = DEFAULT_MAX_SPEAKERS
        self.current_conversation = ""
        self.audio_buffer = []
        self.sentence_thread = None
        self.transcription_thread = None

    def initialize_models(self):
        """Initialize the speaker encoder model; returns True on success."""
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device_str}")

            self.encoder = SpeechBrainEncoder(device=device_str)
            success = self.encoder.load_model()

            if success:
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers
                )
                print("ECAPA-TDNN model loaded successfully!")
                return True
            else:
                print("Failed to load ECAPA-TDNN model")
                return False
        except Exception as e:
            print(f"Model initialization error: {e}")
            return False

    def live_text_detected(self, text):
        """Realtime-transcription callback: shorten the end-of-speech wait at sentence ends."""
        text = text.strip()
        if text:
            sentence_delimiters = '.?!。'
            # Two consecutive partial results ending in punctuation => likely sentence end.
            prob_sentence_end = (
                len(self.last_realtime_text) > 0
                and text[-1] in sentence_delimiters
                and self.last_realtime_text[-1] in sentence_delimiters
            )

            self.last_realtime_text = text

            if prob_sentence_end and FAST_SENTENCE_END:
                self.recorder.stop()
            elif prob_sentence_end:
                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
            else:
                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]

    def process_final_text(self, text):
        """Queue a final transcription, with the recorder's audio for it, for speaker detection."""
        text = text.strip()
        if text:
            try:
                bytes_data = self.recorder.last_transcription_bytes
                # Mark as pending *before* queueing. The worker thread removes it
                # once processed and could otherwise run first, leaving a
                # "Processing:" entry stuck forever.
                self.pending_sentences.append(text)
                self.sentence_queue.put((text, bytes_data))
            except Exception as e:
                print(f"Error processing final text: {e}")

    @staticmethod
    def _to_int16(audio):
        """Normalise a recorder audio snapshot to a 1-D int16 array.

        Accepts raw int16 PCM bytes, an int16 array, or a float array in [-1, 1].
        None yields an empty array.
        NOTE(review): the type of ``last_transcription_bytes`` appears to differ
        between RealtimeSTT versions, so all of these forms are handled.
        """
        if audio is None:
            return np.zeros(0, dtype=np.int16)
        if isinstance(audio, np.ndarray):
            if np.issubdtype(audio.dtype, np.floating):
                return np.clip(audio * 32767, -32768, 32767).astype(np.int16).reshape(-1)
            return audio.astype(np.int16, copy=False).reshape(-1)
        return np.frombuffer(audio, dtype=np.int16)

    def _assign_speaker(self, text, audio):
        """Embed one sentence's audio, choose its speaker and append it to the transcript."""
        audio_int16 = self._to_int16(audio)
        speaker_embedding = self.audio_processor.extract_embedding(audio_int16)

        self.full_sentences.append((text, speaker_embedding))

        # Keep sentence_speakers index-aligned with full_sentences.
        while len(self.sentence_speakers) < len(self.full_sentences) - 1:
            self.sentence_speakers.append(0)

        speaker_id, _similarity = self.speaker_detector.add_embedding(speaker_embedding)
        self.sentence_speakers.append(speaker_id)

    def process_sentence_queue(self):
        """Worker loop: assign speakers to queued sentences while recording is running."""
        while self.is_running:
            try:
                text, audio = self.sentence_queue.get(timeout=1)
            except queue.Empty:
                continue

            try:
                self._assign_speaker(text, audio)
            except Exception as e:
                print(f"Error processing sentence: {e}")
            finally:
                # Always clear the placeholder, even when speaker assignment failed.
                if text in self.pending_sentences:
                    self.pending_sentences.remove(text)
                self.current_conversation = self.get_formatted_conversation()

    def start_recording(self):
        """Create a recorder and start the transcription and speaker-assignment threads."""
        if self.encoder is None or self.audio_processor is None:
            return "Please initialize models first!"
        if self.is_running:
            # A second start would spawn duplicate threads and orphan the recorder.
            return "Recording is already running."

        try:
            recorder_config = {
                'spinner': False,
                'use_microphone': False,
                'model': FINAL_TRANSCRIPTION_MODEL,
                'language': TRANSCRIPTION_LANGUAGE,
                'silero_sensitivity': SILERO_SENSITIVITY,
                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
                'post_speech_silence_duration': SILENCE_THRESHS[1],
                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
                'min_gap_between_recordings': 0,
                'enable_realtime_transcription': True,
                'realtime_processing_pause': 0,
                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
                'on_realtime_transcription_update': self.live_text_detected,
                'beam_size': FINAL_BEAM_SIZE,
                'beam_size_realtime': REALTIME_BEAM_SIZE,
                'buffer_size': BUFFER_SIZE,
                'sample_rate': SAMPLE_RATE,
            }

            self.recorder = AudioToTextRecorder(**recorder_config)

            self.is_running = True
            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
            self.sentence_thread.start()

            self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
            self.transcription_thread.start()

            return "Recording started successfully! FastRTC audio input ready."

        except Exception as e:
            self.is_running = False
            return f"Error starting recording: {e}"

    def run_transcription(self):
        """Run the transcription loop until recording stops."""
        try:
            while self.is_running:
                self.recorder.text(self.process_final_text)
        except Exception as e:
            print(f"Transcription error: {e}")

    def stop_recording(self):
        """Stop recording and release the recorder.

        ``start_recording`` creates a fresh recorder every time. The old one is
        therefore shut down here instead of being leaked together with its
        worker processes. The background threads get a bounded time to exit.
        """
        self.is_running = False
        recorder = self.recorder
        if recorder:
            recorder.stop()
            shutdown = getattr(recorder, "shutdown", None)
            if callable(shutdown):
                try:
                    shutdown()
                except Exception as e:
                    print(f"Error shutting down recorder: {e}")

        for worker in (self.sentence_thread, self.transcription_thread):
            if worker is not None and worker.is_alive() and worker is not threading.current_thread():
                worker.join(timeout=5)

        return "Recording stopped!"

    def clear_conversation(self):
        """Clear all conversation data and reset speaker tracking."""
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.displayed_text = ""
        self.last_realtime_text = ""
        self.current_conversation = "Conversation cleared!"

        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )

        return "Conversation cleared!"

    def update_settings(self, threshold, max_speakers):
        """Update speaker detection settings.

        *max_speakers* is coerced to int because UI sliders may deliver floats
        (e.g. 4.0), which would break list sizing in the detector.
        """
        max_speakers = int(max_speakers)
        self.change_threshold = threshold
        self.max_speakers = max_speakers

        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(threshold)
            self.speaker_detector.set_max_speakers(max_speakers)

        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"

    def get_formatted_conversation(self):
        """Return the transcript as HTML, one coloured line per sentence plus pending ones.

        Transcript text is HTML-escaped so characters such as ``<`` cannot break the markup.
        """
        from html import escape

        try:
            sentences_with_style = []

            for i, sentence in enumerate(self.full_sentences):
                sentence_text, _ = sentence
                if i >= len(self.sentence_speakers):
                    color = "#FFFFFF"
                    speaker_name = "Unknown"
                else:
                    speaker_id = self.sentence_speakers[i]
                    color = self.speaker_detector.get_color_for_speaker(speaker_id)
                    speaker_name = f"Speaker {speaker_id + 1}"

                sentences_with_style.append(
                    f'<span style="color:{color};"><b>{speaker_name}:</b> {escape(sentence_text)}</span>')

            for pending_sentence in self.pending_sentences:
                sentences_with_style.append(
                    f'<span style="color:#60FFFF;"><b>Processing:</b> {escape(pending_sentence)}</span>')

            if sentences_with_style:
                return "<br><br>".join(sentences_with_style)
            else:
                return "Waiting for speech input..."

        except Exception as e:
            return f"Error formatting conversation: {e}"

    def get_status_info(self):
        """Return a Markdown-formatted status summary."""
        if not self.speaker_detector:
            return "Speaker detector not initialized"

        try:
            status = self.speaker_detector.get_status_info()

            status_lines = [
                f"**Current Speaker:** {status['current_speaker'] + 1}",
                f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
                f"**Last Similarity:** {status['last_similarity']:.3f}",
                f"**Change Threshold:** {status['threshold']:.2f}",
                f"**Total Sentences:** {len(self.full_sentences)}",
                "",
                "**Speaker Segment Counts:**"
            ]

            for i in range(status['max_speakers']):
                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
                status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")

            return "\n".join(status_lines)

        except Exception as e:
            return f"Error getting status: {e}"

    def feed_audio_data(self, audio_data):
        """Feed one chunk of 16-bit PCM to the recorder (ignored while not running).

        Float arrays are treated as [-1, 1] audio and scaled to int16 with
        clipping. Other arrays are cast to int16, and non-arrays are passed
        through as bytes.
        """
        if not self.is_running or not self.recorder:
            return

        try:
            if isinstance(audio_data, np.ndarray):
                if audio_data.dtype != np.int16:
                    if np.issubdtype(audio_data.dtype, np.floating):
                        # Clip before the cast: values outside [-1, 1] would otherwise wrap around.
                        audio_data = np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
                    else:
                        audio_data = audio_data.astype(np.int16)

                audio_bytes = audio_data.tobytes()
            else:
                audio_bytes = audio_data

            self.recorder.feed_audio(audio_bytes)

        except Exception as e:
            print(f"Error feeding audio data: {e}")

    @staticmethod
    def _resample(samples, orig_rate, target_rate):
        """Resample a 1-D float32 signal from *orig_rate* to *target_rate* (polyphase, anti-aliased)."""
        if orig_rate == target_rate or samples.size == 0:
            return samples
        from math import gcd
        from scipy.signal import resample_poly

        common = gcd(orig_rate, target_rate)
        return resample_poly(samples, target_rate // common, orig_rate // common).astype(np.float32)

    def process_audio_chunk(self, audio_data, sample_rate=None):
        """Downmix, resample to SAMPLE_RATE and feed one chunk of streamed audio.

        This is the entry point used by the FastRTC handler.

        Args:
            audio_data: Array-like audio. Signed-integer arrays are scaled by their
                dtype's range. Float arrays are assumed to lie in [-1, 1]; if their
                peak exceeds 1.0 they are taken to be int16-scale values cast to
                float and are divided by 32768. 2-D input is downmixed over its
                shorter axis, which is taken to be the channel axis.
            sample_rate: Rate of *audio_data*; defaults to SAMPLE_RATE.
        """
        if not self.is_running or not self.recorder:
            return

        try:
            target_rate = SAMPLE_RATE
            rate = target_rate if sample_rate is None else int(sample_rate)

            samples = np.asarray(audio_data)
            if samples.size == 0:
                return

            if np.issubdtype(samples.dtype, np.signedinteger):
                samples = samples.astype(np.float32) / float(np.iinfo(samples.dtype).max + 1)
            else:
                samples = samples.astype(np.float32)
                peak = float(np.abs(samples).max())
                if peak > 1.0:
                    samples = samples / 32768.0

            if samples.ndim == 2:
                samples = samples.mean(axis=int(np.argmin(samples.shape)))
            samples = samples.reshape(-1)

            samples = self._resample(samples, rate, target_rate)
            self.feed_audio_data(samples)

        except Exception as e:
            print(f"Error processing audio chunk: {e}")
|
|
|
|
| |
|
|
class DiarizationHandler(AsyncStreamHandler):
    """Receive-only FastRTC handler that forwards incoming audio to a diarization system.

    Each frame is converted to mono float32 and resampled to ``self.sample_rate``.
    It is then handed to ``diarization_system.feed_audio_data`` as int16 PCM.
    """

    def __init__(self, diarization_system):
        super().__init__()
        self.diarization_system = diarization_system
        self.audio_queue = Queue()  # NOTE(review): not used by this class
        self.is_processing = False  # NOTE(review): not used by this class
        self.sample_rate = 16000  # rate audio is resampled to before it reaches the recorder

    def copy(self):
        """Return a fresh handler for each new stream connection"""
        return DiarizationHandler(self.diarization_system)

    async def emit(self):
        """Not used in this implementation - we only receive audio"""
        return None

    @staticmethod
    def _unpack_frame(frame, default_rate):
        """Split *frame* into ``(audio, sample_rate)``.

        Handles ``(sample_rate, ndarray)`` tuples, which is FastRTC's documented
        ``receive`` payload. Also handles objects exposing ``data``/``audio``
        (plus an optional ``sample_rate``), and bare audio.
        """
        if isinstance(frame, tuple) and len(frame) == 2 and isinstance(frame[1], np.ndarray):
            return frame[1], int(frame[0])
        if getattr(frame, 'data', None) is not None:
            audio = frame.data
        elif getattr(frame, 'audio', None) is not None:
            audio = frame.audio
        else:
            audio = frame
        return audio, getattr(frame, 'sample_rate', default_rate)

    @staticmethod
    def _to_mono_float32(audio_data):
        """Convert bytes / sequence / ndarray audio to a 1-D float32 array in [-1, 1].

        Bytes are read as int16 PCM. Signed-integer arrays are scaled by their
        dtype's range. 2-D input is averaged over its shorter axis, which is
        taken to be the channel axis; this covers both (channels, samples) and
        (samples, channels). Returns None for unsupported types.
        """
        if isinstance(audio_data, (bytes, bytearray, memoryview)):
            samples = np.frombuffer(audio_data, dtype=np.int16)
        elif isinstance(audio_data, (list, tuple)):
            samples = np.asarray(audio_data, dtype=np.float32)
        elif isinstance(audio_data, np.ndarray):
            samples = audio_data
        else:
            return None

        if samples.size == 0:
            return np.zeros(0, dtype=np.float32)

        if np.issubdtype(samples.dtype, np.signedinteger):
            samples = samples.astype(np.float32) / float(np.iinfo(samples.dtype).max + 1)
        else:
            samples = samples.astype(np.float32)

        if samples.ndim == 2:
            samples = samples.mean(axis=int(np.argmin(samples.shape)))
        return samples.reshape(-1)

    @staticmethod
    def _resample(samples, orig_rate, target_rate):
        """Resample a 1-D float32 signal with an anti-aliased polyphase filter."""
        if orig_rate == target_rate or samples.size == 0:
            return samples
        from math import gcd
        from scipy.signal import resample_poly

        common = gcd(orig_rate, target_rate)
        return resample_poly(samples, target_rate // common, orig_rate // common).astype(np.float32)

    def _prepare_and_feed(self, audio_data, sample_rate):
        """Blocking part of processing: normalise, resample and feed int16 PCM."""
        samples = self._to_mono_float32(audio_data)
        if samples is None or samples.size == 0:
            return
        rate = int(sample_rate) if sample_rate else int(self.sample_rate)
        samples = self._resample(samples, rate, int(self.sample_rate))
        pcm = np.clip(np.round(samples * 32767.0), -32768, 32767).astype(np.int16)
        self.diarization_system.feed_audio_data(pcm)

    async def receive(self, frame):
        """Receive audio data from FastRTC and process it"""
        try:
            if not self.diarization_system.is_running:
                return

            audio_data, sample_rate = self._unpack_frame(frame, self.sample_rate)

            audio_array = self._to_mono_float32(audio_data)
            if audio_array is None:
                print(f"Unknown audio data type: {type(audio_data)}")
                return
            if audio_array.size == 0:
                return

            await self.process_audio_async(audio_array, sample_rate)

        except Exception as e:
            print(f"Error in FastRTC audio receive: {e}")
            import traceback
            traceback.print_exc()

    async def process_audio_async(self, audio_data, sample_rate=16000):
        """Normalise, resample and feed *audio_data* off the event loop."""
        try:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                self._prepare_and_feed,
                audio_data,
                sample_rate
            )
        except Exception as e:
            print(f"Error in async audio processing: {e}")
|
|
|
|
| |
# Module-level singletons shared by the UI callbacks and the API routes. Both
# start out empty and are assigned by create_app() / initialize_system().
diarization_system = audio_handler = None
|
|
|
|
def initialize_system():
    """Load the diarization models and create the FastRTC audio handler.

    If no diarization system was injected (via create_app), the default
    RealtimeSpeakerDiarization backend is created first. Otherwise every
    "Initialize" click would fail when the module is run directly.

    Returns:
        A user-facing status string; failures are reported in it, never raised.
    """
    global audio_handler, diarization_system
    try:
        if diarization_system is None:
            print("No diarization system injected; creating the default RealtimeSpeakerDiarization")
            diarization_system = RealtimeSpeakerDiarization()

        success = diarization_system.initialize_models()
        if success:
            audio_handler = DiarizationHandler(diarization_system)
            return "✅ System initialized successfully! Models loaded and FastRTC handler ready."
        else:
            return "❌ Failed to initialize system. Please check the logs."
    except Exception as e:
        print(f"Initialization error: {e}")
        return f"❌ Initialization error: {str(e)}"
|
|
|
|
def start_recording():
    """Ask the diarization backend to start recording; errors come back as text."""
    try:
        backend = diarization_system
        if backend is None:
            return "❌ System not initialized"
        outcome = backend.start_recording()
        return f"🎙️ {outcome} - FastRTC audio streaming is active."
    except Exception as exc:
        return f"❌ Failed to start recording: {exc}"
|
|
|
|
def stop_recording():
    """Ask the diarization backend to stop recording; errors come back as text."""
    try:
        backend = diarization_system
        if backend is not None:
            return f"⏹️ {backend.stop_recording()}"
        return "❌ System not initialized"
    except Exception as exc:
        return f"❌ Failed to stop recording: {exc}"
|
|
|
|
def clear_conversation():
    """Ask the diarization backend to clear its transcript; errors come back as text."""
    try:
        backend = diarization_system
        if backend is None:
            return "❌ System not initialized"
        return "🗑️ " + str(backend.clear_conversation())
    except Exception as exc:
        return f"❌ Failed to clear conversation: {exc}"
|
|
|
|
def update_settings(threshold, max_speakers):
    """Forward new detection settings to the backend; errors come back as text."""
    try:
        backend = diarization_system
        if backend is None:
            return "❌ System not initialized"
        summary = backend.update_settings(threshold, max_speakers)
        return f"⚙️ {summary}"
    except Exception as exc:
        return f"❌ Failed to update settings: {exc}"
|
|
|
|
def get_conversation():
    """Return the backend's HTML transcript, or an italic placeholder/error message."""
    try:
        backend = diarization_system
        return "<i>System not initialized</i>" if backend is None else backend.get_formatted_conversation()
    except Exception as exc:
        return f"<i>Error getting conversation: {exc}</i>"
|
|
|
|
def get_status():
    """Return the backend's status text, or a placeholder/error message."""
    try:
        backend = diarization_system
        return "System not initialized" if backend is None else backend.get_status_info()
    except Exception as exc:
        return f"Error getting status: {exc}"
|
|
|
|
| |
def create_interface():
    """Build the Gradio Blocks UI: transcript, controls, settings and help panels.

    Callbacks go through the module-level wrappers (initialize_system,
    start_recording, ...). A 2-second timer refreshes the transcript and status.
    Slider defaults and the colour legend are derived from the module constants,
    so the UI shows what the backend actually uses.
    """
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification using FastRTC for low-latency audio streaming.")

        with gr.Row():
            with gr.Column(scale=2):
                conversation_output = gr.HTML(
                    value="<div style='padding: 20px; background: #f5f5f5; border-radius: 10px;'><i>Click 'Initialize System' to start...</i></div>",
                    label="Live Conversation",
                    elem_id="conversation_display"
                )

                with gr.Row():
                    init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", size="lg", interactive=False)
                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", size="lg", interactive=False)
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)

                with gr.Row():
                    gr.HTML("""
                    <div id="fastrtc-container" style="border: 2px solid #ddd; border-radius: 10px; padding: 20px; margin: 10px 0;">
                        <h3>🎵 Audio Stream</h3>
                        <p>FastRTC audio stream will appear here when recording starts.</p>
                        <div id="stream-status" style="padding: 10px; background: #f8f9fa; border-radius: 5px; margin-top: 10px;">
                            Status: Waiting for initialization...
                        </div>
                    </div>
                    """)

                status_output = gr.Textbox(
                    label="System Status",
                    value="System not initialized. Please click 'Initialize System' to begin.",
                    lines=6,
                    interactive=False,
                    show_copy_button=True
                )

            with gr.Column(scale=1):
                gr.Markdown("## ⚙️ Settings")

                # A change is detected when similarity drops *below* the threshold,
                # so a higher value makes the detector more eager to switch speakers.
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.95,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Higher = more sensitive to speaker changes"
                )

                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Number of Speakers"
                )

                update_settings_btn = gr.Button("Update Settings", variant="secondary")

                gr.Markdown("## 🔊 Audio Configuration")
                with gr.Accordion("Advanced Audio Settings", open=False):
                    gr.Markdown(f"""
                    **Current Configuration:**
                    - Sample Rate: {SAMPLE_RATE // 1000}kHz
                    - Audio Format: 16-bit PCM → Float32 (via AudioProcessor)
                    - Channels: Mono (stereo converted automatically)
                    - Buffer Size: {BUFFER_SIZE} samples for real-time processing
                    - Processing: Uses existing AudioProcessor.extract_embedding()
                    """)

                gr.Markdown("## 📝 How to Use")
                gr.Markdown("""
                1. **Initialize**: Click "Initialize System" to load AI models
                2. **Start**: Click "Start Recording" to begin processing
                3. **Connect**: The FastRTC stream will activate automatically
                4. **Allow Access**: Grant microphone permissions when prompted
                5. **Speak**: Talk naturally into your microphone
                6. **Monitor**: Watch real-time transcription with speaker colors
                """)

                with gr.Accordion("💡 Performance Tips", open=False):
                    gr.Markdown("""
                    - Use Chrome/Edge for best FastRTC performance
                    - Ensure stable internet connection
                    - Use headphones to prevent echo
                    - Position microphone 6-12 inches away
                    - Minimize background noise
                    - Allow browser microphone access
                    """)

                gr.Markdown("## 🎨 Speaker Colors")
                # Built from the same palette the transcript uses, so the legend matches it.
                color_html = "".join(
                    f'<div style="margin: 3px 0;"><span style="color:{color}; font-size: 16px; font-weight: bold;">●</span> Speaker {i+1} ({name})</div>'
                    for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES))
                )

                gr.HTML(f"<div style='font-size: 14px;'>{color_html}</div>")

        def refresh_display():
            """Return (conversation HTML, status text) for the periodic refresh."""
            try:
                conversation = get_conversation()
                status = get_status()
                return conversation, status
            except Exception as e:
                error_msg = f"Error refreshing display: {str(e)}"
                return f"<i>{error_msg}</i>", error_msg

        def on_initialize():
            try:
                result = initialize_system()
                success = "successfully" in result.lower()

                conversation, _status = refresh_display()

                return (
                    result,
                    gr.update(interactive=success),
                    gr.update(interactive=success),
                    conversation,
                )
            except Exception as e:
                error_msg = f"❌ Initialization failed: {str(e)}"
                return (
                    error_msg,
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                    "<i>System not ready</i>",
                )

        def on_start():
            try:
                result = start_recording()
                # start_recording() reports failures as text rather than raising, so
                # only flip the buttons when recording really started.
                started = "successfully" in result.lower()
                return (
                    result,
                    gr.update(interactive=not started),
                    gr.update(interactive=started),
                )
            except Exception as e:
                error_msg = f"❌ Failed to start: {str(e)}"
                return (
                    error_msg,
                    gr.update(interactive=True),
                    gr.update(interactive=False),
                )

        def on_stop():
            try:
                result = stop_recording()
                return (
                    result,
                    gr.update(interactive=True),
                    gr.update(interactive=False),
                )
            except Exception as e:
                error_msg = f"❌ Failed to stop: {str(e)}"
                return (
                    error_msg,
                    gr.update(interactive=False),
                    gr.update(interactive=True),
                )

        def on_clear():
            try:
                result = clear_conversation()
                conversation, _status = refresh_display()
                return result, conversation
            except Exception as e:
                error_msg = f"❌ Failed to clear: {str(e)}"
                return error_msg, "<i>Error clearing conversation</i>"

        def on_update_settings(threshold, max_speakers):
            try:
                result = update_settings(threshold, max_speakers)
                return result
            except Exception as e:
                return f"❌ Failed to update settings: {str(e)}"

        init_btn.click(
            on_initialize,
            outputs=[status_output, start_btn, clear_btn, conversation_output]
        )

        start_btn.click(
            on_start,
            outputs=[status_output, start_btn, stop_btn]
        )

        stop_btn.click(
            on_stop,
            outputs=[status_output, start_btn, stop_btn]
        )

        clear_btn.click(
            on_clear,
            outputs=[status_output, conversation_output]
        )

        update_settings_btn.click(
            on_update_settings,
            inputs=[threshold_slider, max_speakers_slider],
            outputs=[status_output]
        )

        refresh_timer = gr.Timer(2.0)
        refresh_timer.tick(
            refresh_display,
            outputs=[conversation_output, status_output]
        )

    return interface
|
|
|
|
| |
def create_fastapi_app():
    """Create the FastAPI app with health, conversation and control endpoints.

    Endpoints read the module-level ``diarization_system`` at request time.
    """
    app = FastAPI(
        title="Real-time Speaker Diarization",
        description="Real-time speech recognition with speaker diarization using FastRTC",
        version="1.0.0"
    )

    router = APIRouter()

    def _is_recording():
        """Current ``is_running`` flag of the diarization system (False if absent)."""
        if diarization_system and hasattr(diarization_system, 'is_running'):
            return diarization_system.is_running
        return False

    @router.get("/health")
    async def health_check():
        """Health check endpoint"""
        return {
            "status": "healthy",
            "timestamp": time.time(),
            "system_initialized": diarization_system is not None and getattr(diarization_system, 'encoder', None) is not None,
            "recording_active": _is_recording()
        }

    @router.get("/api/conversation")
    async def get_conversation_api():
        """Get current conversation"""
        try:
            return {
                "conversation": get_conversation(),
                "status": get_status(),
                "is_recording": _is_recording(),
                "timestamp": time.time()
            }
        except Exception as e:
            return {"error": str(e), "timestamp": time.time()}

    # Declared with plain ``def``: these actions block (model loading, recorder
    # start-up and shutdown). FastAPI runs sync endpoints in its threadpool
    # instead of on the event loop that also serves the audio stream.
    @router.post("/api/control/{action}")
    def control_recording(action: str):
        """Control recording actions"""
        actions = {
            "start": start_recording,
            "stop": stop_recording,
            "clear": clear_conversation,
            "initialize": initialize_system,
        }
        try:
            handler = actions.get(action)
            if handler is None:
                return {"error": "Invalid action. Use: start, stop, clear, or initialize"}

            result = handler()
            return {
                "result": result,
                "is_recording": _is_recording(),
                "timestamp": time.time()
            }
        except Exception as e:
            return {"error": str(e), "timestamp": time.time()}

    app.include_router(router)
    return app
|
|
|
|
| |
def setup_fastrtc_stream(app):
    """Create a receive-only FastRTC audio Stream around ``audio_handler`` and mount it on *app*.

    When ``HF_TOKEN`` is set, Cloudflare TURN credentials are requested through
    FastRTC; this helps clients behind restrictive NATs. Otherwise only public
    STUN servers are configured.

    Returns:
        The Stream, or None when the handler is not ready or setup fails.
        This is best-effort: problems are printed, not raised.
    """
    try:
        if audio_handler is None:
            print("Warning: Audio handler not initialized. Initialize system first.")
            return None

        hf_token = os.environ.get("HF_TOKEN")

        if hf_token:
            # FastRTC accepts a callable here and calls it to obtain fresh credentials.
            rtc_config = get_cloudflare_turn_credentials_async
        else:
            rtc_config = {
                "iceServers": [
                    {"urls": "stun:stun.l.google.com:19302"},
                    {"urls": "stun:stun1.l.google.com:19302"}
                ]
            }

        stream = Stream(
            handler=audio_handler,
            rtc_configuration=rtc_config,
            modality="audio",
            mode="receive"
        )

        # Stream is not an ASGI application, so app.mount(path, stream) would fail
        # on every request. Stream.mount() registers its signalling routes on the
        # FastAPI app instead.
        stream.mount(app)
        print("✅ FastRTC stream configured successfully!")
        return stream

    except Exception as e:
        print(f"⚠️ Warning: Failed to setup FastRTC stream: {e}")
        print("Audio streaming may not work properly.")
        return None
|
|
|
|
| |
def create_app(diarization_sys=None):
    """Create the complete application: FastAPI routes, FastRTC stream and Gradio UI.

    Args:
        diarization_sys: Optional backend to install as the module-level
            ``diarization_system``.

    Returns:
        ``(app, gradio_interface)`` where *app* is the FastAPI app with Gradio
        mounted at "/".
    """
    global diarization_system, audio_handler

    if diarization_sys is not None:
        diarization_system = diarization_sys

    fastapi_app = create_fastapi_app()

    # The FastRTC signalling routes must be registered *before* Gradio is mounted
    # at "/", because a root mount matches every path and would shadow them.
    if diarization_system is not None:
        if audio_handler is None or getattr(audio_handler, "diarization_system", None) is not diarization_system:
            audio_handler = DiarizationHandler(diarization_system)
        setup_fastrtc_stream(fastapi_app)

    gradio_interface = create_interface()

    app = gr.mount_gradio_app(fastapi_app, gradio_interface, path="/")

    return app, gradio_interface
|
|
|
|
| |
if __name__ == "__main__":
    try:
        # Build the default backend here. Without it, create_app() leaves
        # diarization_system unset and the UI cannot initialize anything.
        app, interface = create_app(RealtimeSpeakerDiarization())

        # Serve the combined FastAPI app: the REST endpoints, the FastRTC routes
        # and the Gradio UI at "/". interface.launch() would serve only the
        # Gradio UI and drop the API and audio-stream routes.
        # NOTE: unlike launch(share=True), this does not create a public share link.
        uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))

    except Exception as e:
        print(f"Failed to launch application: {e}")
        import traceback
        traceback.print_exc()

        try:
            interface = create_interface()
            interface.launch(
                server_name="0.0.0.0",
                server_port=int(os.environ.get("PORT", 7860)),
                share=False
            )
        except Exception as fallback_error:
            print(f"Fallback launch also failed: {fallback_error}")
|
|
|
|
| |
def initialize_with_diarization_system(diarization_sys):
    """Install *diarization_sys* as the module-wide backend and build the app around it.

    Returns:
        Whatever ``create_app`` returns: an ``(app, gradio_interface)`` pair.
    """
    global diarization_system
    diarization_system = diarization_sys
    app_and_ui = create_app(diarization_sys)
    return app_and_ui