| |
| |
| |
| |
| |
| |
| from ctypes import c_int, c_int32, c_float, c_size_t, CDLL, c_void_p, POINTER |
| import numpy as np |
| import os |
| import platform |
|
|
class TenVad:
    """ctypes wrapper around the native TEN VAD shared library.

    An instance loads the platform-specific shared library, declares the
    signatures of ``ten_vad_create`` / ``ten_vad_destroy`` /
    ``ten_vad_process`` and creates one native handle. Audio is then fed one
    frame of ``hop_size`` int16 samples at a time through :meth:`process`.
    """

    def __init__(self, hop_size: int = 256, threshold: float = 0.5):
        """Load the native library and create a VAD handle.

        Args:
            hop_size: Number of int16 samples expected per processed frame.
            threshold: Value forwarded to ``ten_vad_create``.

        Raises:
            NotImplementedError: If the current platform is not supported.
            OSError: If the shared library cannot be loaded.
            AssertionError: If the native handle cannot be created.
        """
        self.hop_size = hop_size
        self.threshold = threshold

        # Set the handle before anything that can fail, so that __del__ always
        # finds a well-defined (NULL) handle on a half-constructed instance.
        self.vad_handler = c_void_p(0)
        self.out_probability = c_float()
        self.out_flags = c_int32()
        # Keeps the most recent input frame alive while native code reads it.
        self._input_frame = None

        self.vad_library = self._load_library()

        self.vad_library.ten_vad_create.argtypes = [
            POINTER(c_void_p),
            c_size_t,
            c_float,
        ]
        self.vad_library.ten_vad_create.restype = c_int

        self.vad_library.ten_vad_destroy.argtypes = [POINTER(c_void_p)]
        self.vad_library.ten_vad_destroy.restype = c_int

        self.vad_library.ten_vad_process.argtypes = [
            c_void_p,
            c_void_p,
            c_size_t,
            POINTER(c_float),
            POINTER(c_int32),
        ]
        self.vad_library.ten_vad_process.restype = c_int

        self.create_and_init_handler()

    @staticmethod
    def _library_candidates():
        """Return ``(git_relative_path, pip_relative_path)`` for this platform.

        The first entry is the location inside a source checkout, the second
        the location used as a fallback when the first does not exist.

        Raises:
            NotImplementedError: If no library is known for this platform.
        """
        system = platform.system()
        machine = platform.machine()
        if system == "Linux" and machine == "x86_64":
            return (
                "../lib/Linux/x64/libten_vad.so",
                "./ten_vad_library/libten_vad.so",
            )
        if system == "Darwin":
            return (
                "../lib/macOS/ten_vad.framework/Versions/A/ten_vad",
                "./ten_vad_library/libten_vad",
            )
        if system.upper() == 'WINDOWS':
            if machine.upper() in ['X64', 'X86_64', 'AMD64']:
                return (
                    "../lib/Windows/x64/ten_vad.dll",
                    "./ten_vad_library/ten_vad.dll",
                )
            # Every other Windows machine type falls back to the x86 build.
            return (
                "../lib/Windows/x86/ten_vad.dll",
                "./ten_vad_library/ten_vad.dll",
            )
        raise NotImplementedError(
            f"Unsupported platform: {platform.system()} {platform.machine()}"
        )

    def _load_library(self):
        """Load the shared library, preferring the source-checkout location."""
        git_relative, pip_relative = self._library_candidates()
        # realpath gives an absolute location that does not depend on the
        # current working directory (unlike a cwd-relative path).
        module_dir = os.path.dirname(os.path.realpath(__file__))
        git_path = os.path.join(module_dir, git_relative)
        if os.path.exists(git_path):
            return CDLL(git_path)
        return CDLL(os.path.join(module_dir, pip_relative))

    def create_and_init_handler(self):
        """Create the native handle via ``ten_vad_create``.

        Raises:
            AssertionError: If ``ten_vad_create`` returns a non-zero status.
        """
        # The native call must not live inside an ``assert`` statement:
        # ``python -O`` strips asserts and the handle would never be created.
        status = self.vad_library.ten_vad_create(
            POINTER(c_void_p)(self.vad_handler),
            c_size_t(self.hop_size),
            c_float(self.threshold),
        )
        if status != 0:
            # AssertionError is kept so existing ``except`` clauses still match.
            raise AssertionError("[TEN VAD]: create handler failure!")

    def __del__(self):
        """Destroy the native handle, if one was created."""
        library = getattr(self, "vad_library", None)
        handler = getattr(self, "vad_handler", None)
        # __init__ may have failed before the library or handle existed, and a
        # NULL handle means there is nothing to destroy.
        if library is None or handler is None or not handler.value:
            return
        status = library.ten_vad_destroy(POINTER(c_void_p)(handler))
        if status != 0:
            raise AssertionError("[TEN VAD]: destroy handler failure!")
        # Make a repeated call a no-op instead of a double destroy.
        handler.value = None

    def get_input_data(self, audio_data: np.ndarray):
        """Validate one audio frame and return a pointer to its samples.

        Args:
            audio_data: Array that squeezes to shape ``[hop_size]`` with dtype
                int16.

        Returns:
            ``c_void_p`` addressing ``hop_size`` contiguous int16 samples. The
            backing array is kept referenced on the instance until the next
            call, so the pointer stays valid in between.

        Raises:
            AssertionError: If the shape or dtype does not match.
        """
        audio_data = np.squeeze(audio_data)
        if audio_data.ndim != 1 or audio_data.shape[0] != self.hop_size:
            raise AssertionError(
                "[TEN VAD]: audio data shape should be [%d]" % (self.hop_size)
            )
        if audio_data.dtype != np.int16:
            raise AssertionError(
                "[TEN VAD]: audio data type error, must be int16"
            )
        # The native side reads the samples back to back; a strided view (for
        # example one channel of interleaved stereo) has to be copied first.
        frame = np.ascontiguousarray(audio_data)
        # Hold a reference so a temporary copy is not freed while in use.
        self._input_frame = frame
        return c_void_p(frame.ctypes.data)

    def process(self, audio_data: np.ndarray):
        """Run the detector on one frame of ``hop_size`` int16 samples.

        Args:
            audio_data: Frame accepted by :meth:`get_input_data`.

        Returns:
            Tuple ``(probability, flag)`` holding the float and int32 values
            written by ``ten_vad_process``.

        Raises:
            AssertionError: If the frame has the wrong shape or dtype.
            RuntimeError: If ``ten_vad_process`` returns a non-zero status
                (treated as failure, like the create/destroy calls).
        """
        input_pointer = self.get_input_data(audio_data)
        status = self.vad_library.ten_vad_process(
            self.vad_handler,
            input_pointer,
            c_size_t(self.hop_size),
            POINTER(c_float)(self.out_probability),
            POINTER(c_int32)(self.out_flags),
        )
        if status != 0:
            # Without this check stale values from a previous frame would be
            # returned as if they belonged to this one.
            raise RuntimeError("[TEN VAD]: process failure!")
        return self.out_probability.value, self.out_flags.value
|
|
|
|
|
|