from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
import torchaudio

from config.settings import settings
|
|
class SpeechBrainVAD:
    """Buffering voice-activity detector that wraps a SpeechBrain VAD model.

    Audio chunks pushed through :meth:`process_stream` are accumulated in
    ``audio_buffer``.  Once ``BUFFER_SECONDS`` of audio is available the
    buffer is handed to the model, every detected segment longer than
    ``MIN_SPEECH_SECONDS`` is passed to the registered callback, and only
    the last ``KEEP_SECONDS`` of audio is carried over to the next round.

    Errors are reported with ``print`` and never propagate to the caller.
    """

    # Seconds of audio that must be buffered before the model is run.
    BUFFER_SECONDS = 2.0
    # A detected segment must be strictly longer than this to be reported.
    MIN_SPEECH_SECONDS = 0.5
    # Seconds of trailing audio kept in the buffer after each model run.
    KEEP_SECONDS = 0.5
    # Amount subtracted from settings.VAD_THRESHOLD before it is used.
    THRESHOLD_MARGIN = 0.1
    # Amount added to settings.VAD_MIN_SILENCE_DURATION before it is used.
    EXTRA_SILENCE_SECONDS = 0.3

    def __init__(self):
        """Set up empty streaming state and try to load the VAD model."""
        self.model = None
        self.sample_rate = settings.SAMPLE_RATE
        self.is_streaming = False
        self.speech_callback: Optional[Callable] = None
        self.audio_buffer = []
        # One resampler per (source rate, target rate) pair, built on demand.
        self._resamplers: Dict[Tuple[int, int], object] = {}
        self._initialize_model()

    def _initialize_model(self):
        """Load the VAD model from SpeechBrain.

        On any failure the error is printed and ``self.model`` stays
        ``None``; the other methods check for that.
        """
        try:
            try:
                # Module path used by SpeechBrain's current inference API.
                from speechbrain.inference.VAD import VAD
            except ImportError:
                # Older releases only provide the ``pretrained`` module.
                from speechbrain.pretrained import VAD
            print("🔄 Đang tải VAD model từ SpeechBrain...")
            self.model = VAD.from_hparams(
                source=settings.VAD_MODEL,
                savedir=f"/tmp/{settings.VAD_MODEL.replace('/', '_')}"
            )
            print("✅ Đã tải VAD model thành công")
        except Exception as e:
            print(f"❌ Lỗi tải VAD model: {e}")
            self.model = None

    @staticmethod
    def _flatten_chunk(audio_chunk) -> np.ndarray:
        """Return *audio_chunk* as a 1-D array of samples.

        Keeps ``(frames, 1)`` / ``(1, frames)`` shaped mono chunks from
        being buffered as rows instead of individual samples.
        NOTE(review): multi-channel input is flattened as-is, not
        downmixed -- confirm that callers only pass mono audio.
        """
        return np.asarray(audio_chunk).reshape(-1)

    def start_stream(self, speech_callback: Callable):
        """Start streaming and register the callback for detected speech.

        Args:
            speech_callback: Called as ``speech_callback(audio, sample_rate)``
                for each reported speech segment.

        Returns:
            ``False`` when no model is loaded, ``True`` otherwise.
        """
        if self.model is None:
            print("❌ VAD model chưa được khởi tạo")
            return False

        self.is_streaming = True
        self.speech_callback = speech_callback
        self.audio_buffer = []
        print("🎙️ Bắt đầu VAD streaming...")
        return True

    def stop_stream(self):
        """Stop streaming and drop the callback and any buffered audio."""
        self.is_streaming = False
        self.speech_callback = None
        self.audio_buffer = []
        print("🛑 Đã dừng VAD streaming")

    def process_stream(self, audio_chunk: np.ndarray, sample_rate: int):
        """Buffer one audio chunk and run the VAD once enough has accumulated.

        Does nothing unless streaming is active and a model is loaded.

        Args:
            audio_chunk: Audio samples; flattened to 1-D before buffering.
            sample_rate: Sample rate of *audio_chunk*; the chunk is resampled
                to ``self.sample_rate`` when the two differ.
        """
        if not self.is_streaming or self.model is None:
            return

        try:
            samples = self._flatten_chunk(audio_chunk)
            if samples.size == 0:
                return

            if sample_rate != self.sample_rate:
                samples = self._resample_audio(samples, sample_rate, self.sample_rate)

            self.audio_buffer.extend(samples)

            buffer_duration = len(self.audio_buffer) / self.sample_rate
            if buffer_duration >= self.BUFFER_SECONDS:
                self._process_buffer()

        except Exception as e:
            print(f"❌ Lỗi xử lý VAD: {e}")

    def _process_buffer(self):
        """Run the model over the buffered audio and report speech segments.

        Each ``(start, end)`` boundary returned by the model is treated as
        seconds and converted to sample indices.  Segments longer than
        ``MIN_SPEECH_SECONDS`` are passed to ``speech_callback``.  Afterwards
        the buffer is cut down to its last ``KEEP_SECONDS``; on any error the
        buffer is cleared instead.
        """
        if not self.audio_buffer:
            return

        try:
            # Convert the list once; the tensor and every segment slice
            # below are taken from this array.
            samples = np.asarray(self.audio_buffer)
            audio_tensor = torch.tensor(samples, dtype=torch.float32).unsqueeze(0)

            # NOTE(review): SpeechBrain's documented
            # VAD.get_speech_segments() takes an audio file path and
            # activation_th / close_th-style keywords -- confirm that this
            # tensor-based call matches the installed version.
            boundaries = self.model.get_speech_segments(
                audio_tensor,
                threshold=settings.VAD_THRESHOLD - self.THRESHOLD_MARGIN,
                min_silence_duration=(
                    settings.VAD_MIN_SILENCE_DURATION + self.EXTRA_SILENCE_SECONDS
                ),
                speech_pad_duration=settings.VAD_SPEECH_PAD_DURATION
            )

            if len(boundaries) > 0:
                min_samples = self.sample_rate * self.MIN_SPEECH_SECONDS
                for start, end in boundaries:
                    start_sample = int(start * self.sample_rate)
                    end_sample = int(end * self.sample_rate)

                    # Copy so the callback never holds a view into ``samples``.
                    speech_audio = samples[start_sample:end_sample].copy()

                    if len(speech_audio) > min_samples:
                        print(f"🎯 VAD phát hiện speech: {len(speech_audio)/self.sample_rate:.2f}s")

                        if self.speech_callback:
                            self.speech_callback(speech_audio, self.sample_rate)

            # Carry the tail of the buffer over into the next round.
            keep_samples = int(self.sample_rate * self.KEEP_SECONDS)
            if len(self.audio_buffer) > keep_samples:
                self.audio_buffer = self.audio_buffer[-keep_samples:]
            else:
                self.audio_buffer = []

        except Exception as e:
            print(f"❌ Lỗi xử lý VAD buffer: {e}")
            self.audio_buffer = []

    def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample *audio* from *orig_sr* to *target_sr*.

        Returns the input unchanged when the rates already match.  If
        resampling fails, a warning is printed and the input is returned
        as-is, i.e. still at *orig_sr*.
        """
        if orig_sr == target_sr:
            return audio

        try:
            key = (orig_sr, target_sr)
            resampler = self._resamplers.get(key)
            if resampler is None:
                # Building the transform is the expensive part, so keep one
                # per rate pair instead of rebuilding it for every chunk.
                resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
                self._resamplers[key] = resampler

            audio_tensor = torch.tensor(np.asarray(audio), dtype=torch.float32).unsqueeze(0)
            resampled = resampler(audio_tensor)
            return resampled.squeeze(0).numpy()
        except Exception as e:
            print(f"⚠️ Lỗi resample: {e}")
            return audio

    def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
        """Report whether *audio_chunk* looks like speech.

        Returns ``True`` when no model is loaded or the check itself fails,
        and ``False`` for an empty chunk.  Otherwise the chunk is flattened,
        resampled if needed, and counted as speech when the mean of the
        model's ``get_speech_prob_chunk`` output exceeds
        ``settings.VAD_THRESHOLD - THRESHOLD_MARGIN``.
        """
        if self.model is None:
            return True

        try:
            samples = self._flatten_chunk(audio_chunk)
            if samples.size == 0:
                # No samples at all: nothing the model could score.
                return False

            if sample_rate != self.sample_rate:
                samples = self._resample_audio(samples, sample_rate, self.sample_rate)

            audio_tensor = torch.tensor(np.asarray(samples), dtype=torch.float32).unsqueeze(0)
            prob_speech = self.model.get_speech_prob_chunk(audio_tensor)

            return prob_speech.mean().item() > (settings.VAD_THRESHOLD - self.THRESHOLD_MARGIN)

        except Exception as e:
            print(f"❌ Lỗi kiểm tra speech: {e}")
            return True