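| """SenseVoice speech recognition demo for Rockchip NPUs (RKNN). |
|
| Segments audio with an FSMN VAD model running on ONNX Runtime, then decodes |
| each speech segment with a SenseVoice encoder exported to RKNN and a |
| SentencePiece tokenizer. |
| """ |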
| import argparse |
| import logging |
| import math |
| import os |
| import time |
| import warnings |
| from enum import Enum |
| from pathlib import Path |
| from typing import Any, Dict, List, Tuple, Union |
|
|
| import kaldi_native_fbank as knf |
| import numpy as np |
| import sentencepiece as spm |
| import soundfile as sf |
| import yaml |
| from onnxruntime import (GraphOptimizationLevel, InferenceSession, |
| SessionOptions, get_available_providers, get_device) |
| from rknnlite.api.rknn_lite import RKNNLite |
|
|
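| # The RKNN encoder graph is exported with a fixed time dimension; shorter |
| # inputs are zero-padded to this length before inference. |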
| RKNN_INPUT_LEN = 171 |
|
|
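| # Gain applied to the speech features before encoding (1 leaves them unchanged). |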
| SPEECH_SCALE = 1 |
|
|
| class VadOrtInferRuntimeSession: |
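| """ONNX Runtime session wrapper for the FSMN-VAD model, with CUDA fallback to CPU.""" |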
| def __init__(self, config, root_dir: Path): |
| sess_opt = SessionOptions() |
| sess_opt.log_severity_level = 4 |
| sess_opt.enable_cpu_mem_arena = False |
| sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL |
|
|
| cuda_ep = "CUDAExecutionProvider" |
| cpu_ep = "CPUExecutionProvider" |
| cpu_provider_options = { |
| "arena_extend_strategy": "kSameAsRequested", |
| } |
|
|
| EP_list = [] |
| if ( |
| config["use_cuda"] |
| and get_device() == "GPU" |
| and cuda_ep in get_available_providers() |
| ): |
| EP_list = [(cuda_ep, config[cuda_ep])] |
| EP_list.append((cpu_ep, cpu_provider_options)) |
|
|
| config["model_path"] = root_dir / str(config["model_path"]) |
| self._verify_model(config["model_path"]) |
| logging.info(f"Loading onnx model at {str(config['model_path'])}") |
| self.session = InferenceSession( |
| str(config["model_path"]), sess_options=sess_opt, providers=EP_list |
| ) |
|
|
| if config["use_cuda"] and cuda_ep not in self.session.get_providers(): |
| logging.warning( |
| f"{cuda_ep} is not available for current env, " |
| f"the inference part is automatically shifted to be " |
| f"executed under {cpu_ep}.\n " |
| "Please ensure the installed onnxruntime-gpu version" |
| " matches your cuda and cudnn version, " |
| "you can check their relations from the offical web site: " |
| "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", |
| RuntimeWarning, |
| ) |
|
|
| def __call__(self, input_content) -> List[np.ndarray]: |
| if isinstance(input_content, list): |
| input_dict = { |
| "speech": input_content[0], |
| "in_cache0": input_content[1], |
| "in_cache1": input_content[2], |
| "in_cache2": input_content[3], |
| "in_cache3": input_content[4], |
| } |
| else: |
| input_dict = {"speech": input_content} |
|
|
| return self.session.run(None, input_dict) |
|
|
| def get_input_names( |
| self, |
| ): |
| return [v.name for v in self.session.get_inputs()] |
|
|
| def get_output_names( |
| self, |
| ): |
| return [v.name for v in self.session.get_outputs()] |
|
|
| def get_character_list(self, key="character"): |
| # Note: have_key() must be called first to populate self.meta_dict. |
| return self.meta_dict[key].splitlines() |
|
|
| def have_key(self, key="character") -> bool: |
| self.meta_dict = self.session.get_modelmeta().custom_metadata_map |
| return key in self.meta_dict |
|
|
| @staticmethod |
| def _verify_model(model_path): |
| model_path = Path(model_path) |
| if not model_path.exists(): |
| raise FileNotFoundError(f"{model_path} does not exist.") |
| if not model_path.is_file(): |
| raise FileExistsError(f"{model_path} is not a file.") |
|
|
|
|
| formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" |
| logging.basicConfig(format=formatter, level=logging.INFO) |
|
|
|
|
| class OrtInferRuntimeSession: |
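| """Generic ONNX Runtime session wrapper; uses CUDA when device_id >= 0 and it is available.""" |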
| def __init__(self, model_file, device_id=-1, intra_op_num_threads=4): |
| device_id = str(device_id) |
| sess_opt = SessionOptions() |
| sess_opt.intra_op_num_threads = intra_op_num_threads |
| sess_opt.log_severity_level = 4 |
| sess_opt.enable_cpu_mem_arena = False |
| sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL |
|
|
| cuda_ep = "CUDAExecutionProvider" |
| cuda_provider_options = { |
| "device_id": device_id, |
| "arena_extend_strategy": "kNextPowerOfTwo", |
| "cudnn_conv_algo_search": "EXHAUSTIVE", |
| "do_copy_in_default_stream": "true", |
| } |
| cpu_ep = "CPUExecutionProvider" |
| cpu_provider_options = { |
| "arena_extend_strategy": "kSameAsRequested", |
| } |
|
|
| EP_list = [] |
| if ( |
| device_id != "-1" |
| and get_device() == "GPU" |
| and cuda_ep in get_available_providers() |
| ): |
| EP_list = [(cuda_ep, cuda_provider_options)] |
| EP_list.append((cpu_ep, cpu_provider_options)) |
|
|
| self._verify_model(model_file) |
|
|
| self.session = InferenceSession( |
| model_file, sess_options=sess_opt, providers=EP_list |
| ) |
|
|
| # Free the serialized model buffer if one was passed in as bytes. |
| del model_file |
|
|
| if device_id != "-1" and cuda_ep not in self.session.get_providers(): |
| warnings.warn( |
| f"{cuda_ep} is not available in the current environment; inference will fall back to {cpu_ep}.\n" |
| "Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; " |
| "see the official guide: " |
| "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", |
| RuntimeWarning, |
| ) |
|
|
| def __call__(self, input_content) -> List[np.ndarray]: |
| input_dict = dict(zip(self.get_input_names(), input_content)) |
| try: |
| return self.session.run(self.get_output_names(), input_dict) |
| except Exception as e: |
| raise RuntimeError("ONNXRuntime inference failed.") from e |
|
|
| def get_input_names( |
| self, |
| ): |
| return [v.name for v in self.session.get_inputs()] |
|
|
| def get_output_names( |
| self, |
| ): |
| return [v.name for v in self.session.get_outputs()] |
|
|
| def get_character_list(self, key="character"): |
| # Note: have_key() must be called first to populate self.meta_dict. |
| return self.meta_dict[key].splitlines() |
|
|
| def have_key(self, key="character") -> bool: |
| self.meta_dict = self.session.get_modelmeta().custom_metadata_map |
| return key in self.meta_dict |
|
|
| @staticmethod |
| def _verify_model(model_path): |
| model_path = Path(model_path) |
| if not model_path.exists(): |
| raise FileNotFoundError(f"{model_path} does not exist.") |
| if not model_path.is_file(): |
| raise FileExistsError(f"{model_path} is not a file.") |
|
|
|
|
| def log_softmax(x: np.ndarray) -> np.ndarray: |
| """Numerically stable log-softmax along the last axis.""" |
| x_max = np.max(x, axis=-1, keepdims=True) |
| shifted = x - x_max |
| return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True)) |
|
|
|
|
| class SenseVoiceInferenceSession: |
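| """SenseVoice inference: prompt embeddings + RKNN-accelerated encoder + SentencePiece decoding.""" |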
| def __init__( |
| self, |
| embedding_model_file, |
| encoder_model_file, |
| bpe_model_file, |
| device_id=-1, |
| intra_op_num_threads=4, |
| ): |
| logging.info(f"Loading model from {embedding_model_file}") |
|
|
| self.embedding = np.load(embedding_model_file) |
| logging.info(f"Loading model {encoder_model_file}") |
| start = time.time() |
| self.encoder = RKNNLite(verbose=False) |
| self.encoder.load_rknn(encoder_model_file) |
| self.encoder.init_runtime() |
|
|
| logging.info( |
| f"Loading {encoder_model_file} takes {time.time() - start:.2f} seconds" |
| ) |
| self.blank_id = 0 |
| self.sp = spm.SentencePieceProcessor() |
| self.sp.load(bpe_model_file) |
|
|
| def __call__(self, speech, language, use_itn: bool) -> str: |
| language_query = self.embedding[[[language]]] |
|
|
| |
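| # Embedding rows 14/15 select the with-ITN / without-ITN text-norm prompts; |
| # rows 1 and 2 are the event and emotion query prompts (SenseVoice embedding layout). |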
| text_norm_query = self.embedding[[[14 if use_itn else 15]]] |
| event_emo_query = self.embedding[[[1, 2]]] |
|
|
| |
| speech = speech * SPEECH_SCALE |
| |
| input_content = np.concatenate( |
| [ |
| language_query, |
| event_emo_query, |
| text_norm_query, |
| speech, |
| ], |
| axis=1, |
| ).astype(np.float32) |
| logging.debug(f"input shape: {input_content.shape}") |
|
| # Pad along the time axis to the fixed length the RKNN graph expects. |
| if input_content.shape[1] > RKNN_INPUT_LEN: |
| raise ValueError( |
| f"Input length {input_content.shape[1]} exceeds RKNN_INPUT_LEN ({RKNN_INPUT_LEN})" |
| ) |
| input_content = np.pad( |
| input_content, |
| ((0, 0), (0, RKNN_INPUT_LEN - input_content.shape[1]), (0, 0)), |
| ) |
| logging.debug(f"padded shape: {input_content.shape}") |
| start_time = time.time() |
| encoder_out = self.encoder.inference(inputs=[input_content])[0] |
| logging.info(f"encoder inference time: {time.time() - start_time:.2f} seconds") |
| |
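| # Greedy CTC-style decoding: collapse consecutive repeats, then drop blanks. |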
| def unique_consecutive(arr): |
| if len(arr) == 0: |
| return arr |
| |
| mask = np.append([True], arr[1:] != arr[:-1]) |
| out = arr[mask] |
| out = out[out != self.blank_id] |
| return out.tolist() |
| |
| |
| |
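| # encoder_out[0] is assumed to be laid out (vocab, time), so argmax over axis 0 picks one token id per frame. |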
| hypos = unique_consecutive(encoder_out[0].argmax(axis=0)) |
| text = self.sp.DecodeIds(hypos) |
| return text |
|
|
|
|
|
| class WavFrontend: |
| """Conventional frontend structure for ASR.""" |
|
|
| def __init__( |
| self, |
| cmvn_file=None, |
| fs=16000, |
| window="hamming", |
| n_mels=80, |
| frame_length=25, |
| frame_shift=10, |
| lfr_m=7, |
| lfr_n=6, |
| dither: float = 0.0, |
| **kwargs, |
| ) -> None: |
| opts = knf.FbankOptions() |
| opts.frame_opts.samp_freq = fs |
| opts.frame_opts.dither = dither |
| opts.frame_opts.window_type = window |
| opts.frame_opts.frame_shift_ms = float(frame_shift) |
| opts.frame_opts.frame_length_ms = float(frame_length) |
| opts.mel_opts.num_bins = n_mels |
| opts.energy_floor = 0 |
| opts.frame_opts.snip_edges = True |
| opts.mel_opts.debug_mel = False |
| self.opts = opts |
|
|
| self.lfr_m = lfr_m |
| self.lfr_n = lfr_n |
| self.cmvn_file = cmvn_file |
|
|
| if self.cmvn_file: |
| self.cmvn = self.load_cmvn() |
| self.fbank_fn = None |
| self.fbank_beg_idx = 0 |
| self.reset_status() |
|
|
| def reset_status(self): |
| self.fbank_fn = knf.OnlineFbank(self.opts) |
| self.fbank_beg_idx = 0 |
|
|
| def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: |
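| """Compute Kaldi-compatible fbank features; the waveform is scaled to int16 range first.""" |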
| waveform = waveform * (1 << 15) |
| self.fbank_fn = knf.OnlineFbank(self.opts) |
| self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) |
| frames = self.fbank_fn.num_frames_ready |
| mat = np.empty([frames, self.opts.mel_opts.num_bins]) |
| for i in range(frames): |
| mat[i, :] = self.fbank_fn.get_frame(i) |
| feat = mat.astype(np.float32) |
| feat_len = np.array(mat.shape[0]).astype(np.int32) |
| return feat, feat_len |
|
|
| def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: |
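| """Apply LFR stacking and (if a cmvn file was given) CMVN normalization.""" |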
| if self.lfr_m != 1 or self.lfr_n != 1: |
| feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n) |
|
|
| if self.cmvn_file: |
| feat = self.apply_cmvn(feat) |
|
|
| feat_len = np.array(feat.shape[0]).astype(np.int32) |
| return feat, feat_len |
|
|
| def load_audio(self, filename) -> Tuple[np.ndarray, int]: |
| data, sample_rate = sf.read( |
| filename, |
| always_2d=True, |
| dtype="float32", |
| ) |
| assert ( |
| sample_rate == 16000 |
| ), f"Only 16000 Hz is supported, but got {sample_rate}Hz" |
| self.sample_rate = sample_rate |
| data = data[:, 0] |
| samples = np.ascontiguousarray(data) |
|
|
| return samples, sample_rate |
|
|
| @staticmethod |
| def apply_lfr(inputs: np.ndarray, lfr_m, lfr_n) -> np.ndarray: |
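| """Low frame rate: stack lfr_m consecutive frames at every lfr_n-frame step.""" |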
| LFR_inputs = [] |
|
|
| T = inputs.shape[0] |
| T_lfr = int(np.ceil(T / lfr_n)) |
| left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1)) |
| inputs = np.vstack((left_padding, inputs)) |
| T = T + (lfr_m - 1) // 2 |
| for i in range(T_lfr): |
| if lfr_m <= T - i * lfr_n: |
| LFR_inputs.append( |
| (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1) |
| ) |
| else: |
| |
| num_padding = lfr_m - (T - i * lfr_n) |
| frame = inputs[i * lfr_n :].reshape(-1) |
| for _ in range(num_padding): |
| frame = np.hstack((frame, inputs[-1])) |
|
|
| LFR_inputs.append(frame) |
| LFR_outputs = np.vstack(LFR_inputs).astype(np.float32) |
| return LFR_outputs |
|
|
| def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray: |
| """ |
| Apply CMVN: the Kaldi .mvn file stores negated means and inverse stds, |
| hence the add-then-multiply below. |
| """ |
| frame, dim = inputs.shape |
| means = np.tile(self.cmvn[0:1, :dim], (frame, 1)) |
| vars = np.tile(self.cmvn[1:2, :dim], (frame, 1)) |
| inputs = (inputs + means) * vars |
| return inputs |
|
|
| def get_features(self, inputs: Union[str, np.ndarray]): |
| if isinstance(inputs, str): |
| inputs, _ = self.load_audio(inputs) |
|
|
| fbank, _ = self.fbank(inputs) |
| feats = self.apply_cmvn(self.apply_lfr(fbank, self.lfr_m, self.lfr_n)) |
| return feats |
|
|
| def load_cmvn( |
| self, |
| ) -> np.ndarray: |
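| """Parse the Kaldi-format .mvn file into a (2, dim) array of [shift, scale] rows.""" |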
| with open(self.cmvn_file, "r", encoding="utf-8") as f: |
| lines = f.readlines() |
|
|
| means_list = [] |
| vars_list = [] |
| for i in range(len(lines)): |
| line_item = lines[i].split() |
| if line_item[0] == "<AddShift>": |
| line_item = lines[i + 1].split() |
| if line_item[0] == "<LearnRateCoef>": |
| add_shift_line = line_item[3 : (len(line_item) - 1)] |
| means_list = list(add_shift_line) |
| continue |
| elif line_item[0] == "<Rescale>": |
| line_item = lines[i + 1].split() |
| if line_item[0] == "<LearnRateCoef>": |
| rescale_line = line_item[3 : (len(line_item) - 1)] |
| vars_list = list(rescale_line) |
| continue |
|
|
| means = np.array(means_list).astype(np.float64) |
| vars = np.array(vars_list).astype(np.float64) |
| cmvn = np.array([means, vars]) |
| return cmvn |
|
|
|
|
|
|
|
|
|
| def read_yaml(yaml_path: Union[str, Path]) -> Dict: |
| if not Path(yaml_path).exists(): |
| raise FileNotFoundError(f"The {yaml_path} does not exist.") |
|
|
| with open(str(yaml_path), "rb") as f: |
| data = yaml.load(f, Loader=yaml.Loader) |
| return data |
|
|
|
|
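| # FSMN-VAD post-processing: state machine and options (following FunASR's implementation). |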
| class VadStateMachine(Enum): |
| kVadInStateStartPointNotDetected = 1 |
| kVadInStateInSpeechSegment = 2 |
| kVadInStateEndPointDetected = 3 |
|
|
|
|
| class FrameState(Enum): |
| kFrameStateInvalid = -1 |
| kFrameStateSpeech = 1 |
| kFrameStateSil = 0 |
|
|
|
|
| |
| class AudioChangeState(Enum): |
| kChangeStateSpeech2Speech = 0 |
| kChangeStateSpeech2Sil = 1 |
| kChangeStateSil2Sil = 2 |
| kChangeStateSil2Speech = 3 |
| kChangeStateNoBegin = 4 |
| kChangeStateInvalid = 5 |
|
|
|
|
| class VadDetectMode(Enum): |
| kVadSingleUtteranceDetectMode = 0 |
| kVadMultipleUtteranceDetectMode = 1 |
|
|
|
|
| class VADXOptions: |
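| """VAD tuning options; durations are in milliseconds unless noted otherwise.""" |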
| def __init__( |
| self, |
| sample_rate = 16000, |
| detect_mode = VadDetectMode.kVadMultipleUtteranceDetectMode.value, |
| snr_mode = 0, |
| max_end_silence_time = 800, |
| max_start_silence_time = 3000, |
| do_start_point_detection: bool = True, |
| do_end_point_detection: bool = True, |
| window_size_ms = 200, |
| sil_to_speech_time_thres = 150, |
| speech_to_sil_time_thres = 150, |
| speech_2_noise_ratio: float = 1.0, |
| do_extend = 1, |
| lookback_time_start_point = 200, |
| lookahead_time_end_point = 100, |
| max_single_segment_time = 60000, |
| nn_eval_block_size = 8, |
| dcd_block_size = 4, |
| snr_thres = -100.0, |
| noise_frame_num_used_for_snr = 100, |
| decibel_thres = -100.0, |
| speech_noise_thres: float = 0.6, |
| fe_prior_thres: float = 1e-4, |
| silence_pdf_num = 1, |
| sil_pdf_ids: List[int] = [0], |
| speech_noise_thresh_low: float = -0.1, |
| speech_noise_thresh_high: float = 0.3, |
| output_frame_probs: bool = False, |
| frame_in_ms = 10, |
| frame_length_ms = 25, |
| ): |
| self.sample_rate = sample_rate |
| self.detect_mode = detect_mode |
| self.snr_mode = snr_mode |
| self.max_end_silence_time = max_end_silence_time |
| self.max_start_silence_time = max_start_silence_time |
| self.do_start_point_detection = do_start_point_detection |
| self.do_end_point_detection = do_end_point_detection |
| self.window_size_ms = window_size_ms |
| self.sil_to_speech_time_thres = sil_to_speech_time_thres |
| self.speech_to_sil_time_thres = speech_to_sil_time_thres |
| self.speech_2_noise_ratio = speech_2_noise_ratio |
| self.do_extend = do_extend |
| self.lookback_time_start_point = lookback_time_start_point |
| self.lookahead_time_end_point = lookahead_time_end_point |
| self.max_single_segment_time = max_single_segment_time |
| self.nn_eval_block_size = nn_eval_block_size |
| self.dcd_block_size = dcd_block_size |
| self.snr_thres = snr_thres |
| self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr |
| self.decibel_thres = decibel_thres |
| self.speech_noise_thres = speech_noise_thres |
| self.fe_prior_thres = fe_prior_thres |
| self.silence_pdf_num = silence_pdf_num |
| self.sil_pdf_ids = sil_pdf_ids |
| self.speech_noise_thresh_low = speech_noise_thresh_low |
| self.speech_noise_thresh_high = speech_noise_thresh_high |
| self.output_frame_probs = output_frame_probs |
| self.frame_in_ms = frame_in_ms |
| self.frame_length_ms = frame_length_ms |
|
|
|
|
| class E2EVadSpeechBufWithDoa(object): |
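| """One detected speech segment (start/end in ms) plus a direction-of-arrival placeholder.""" |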
| def __init__(self): |
| self.start_ms = 0 |
| self.end_ms = 0 |
| self.buffer = [] |
| self.contain_seg_start_point = False |
| self.contain_seg_end_point = False |
| self.doa = 0 |
|
|
| def reset(self): |
| self.start_ms = 0 |
| self.end_ms = 0 |
| self.buffer = [] |
| self.contain_seg_start_point = False |
| self.contain_seg_end_point = False |
| self.doa = 0 |
|
|
|
|
| class E2EVadFrameProb(object): |
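| """Per-frame speech/noise probabilities, kept when output_frame_probs is enabled.""" |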
| def __init__(self): |
| self.noise_prob = 0.0 |
| self.speech_prob = 0.0 |
| self.score = 0.0 |
| self.frame_id = 0 |
| self.frm_state = 0 |
|
|
|
|
| class WindowDetector(object): |
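| """Sliding-window smoother that converts noisy per-frame states into sil/speech transitions.""" |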
| def __init__( |
| self, |
| window_size_ms, |
| sil_to_speech_time, |
| speech_to_sil_time, |
| frame_size_ms, |
| ): |
| self.window_size_ms = window_size_ms |
| self.sil_to_speech_time = sil_to_speech_time |
| self.speech_to_sil_time = speech_to_sil_time |
| self.frame_size_ms = frame_size_ms |
|
|
| self.win_size_frame = int(window_size_ms / frame_size_ms) |
| self.win_sum = 0 |
| self.win_state = [0] * self.win_size_frame |
|
|
| self.cur_win_pos = 0 |
| self.pre_frame_state = FrameState.kFrameStateSil |
| self.cur_frame_state = FrameState.kFrameStateSil |
| self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms) |
| self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms) |
|
|
| self.voice_last_frame_count = 0 |
| self.noise_last_frame_count = 0 |
| self.hydre_frame_count = 0 |
|
|
| def reset(self) -> None: |
| self.cur_win_pos = 0 |
| self.win_sum = 0 |
| self.win_state = [0] * self.win_size_frame |
| self.pre_frame_state = FrameState.kFrameStateSil |
| self.cur_frame_state = FrameState.kFrameStateSil |
| self.voice_last_frame_count = 0 |
| self.noise_last_frame_count = 0 |
| self.hydre_frame_count = 0 |
|
|
| def get_win_size(self) -> int: |
| return int(self.win_size_frame) |
|
|
| def detect_one_frame( |
| self, frameState: FrameState, frame_count |
| ) -> AudioChangeState: |
| if frameState == FrameState.kFrameStateSpeech: |
| cur_frame_state = 1 |
| elif frameState == FrameState.kFrameStateSil: |
| cur_frame_state = 0 |
| else: |
| return AudioChangeState.kChangeStateInvalid |
| self.win_sum -= self.win_state[self.cur_win_pos] |
| self.win_sum += cur_frame_state |
| self.win_state[self.cur_win_pos] = cur_frame_state |
| self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame |
|
|
| if ( |
| self.pre_frame_state == FrameState.kFrameStateSil |
| and self.win_sum >= self.sil_to_speech_frmcnt_thres |
| ): |
| self.pre_frame_state = FrameState.kFrameStateSpeech |
| return AudioChangeState.kChangeStateSil2Speech |
|
|
| if ( |
| self.pre_frame_state == FrameState.kFrameStateSpeech |
| and self.win_sum <= self.speech_to_sil_frmcnt_thres |
| ): |
| self.pre_frame_state = FrameState.kFrameStateSil |
| return AudioChangeState.kChangeStateSpeech2Sil |
|
|
| if self.pre_frame_state == FrameState.kFrameStateSil: |
| return AudioChangeState.kChangeStateSil2Sil |
| if self.pre_frame_state == FrameState.kFrameStateSpeech: |
| return AudioChangeState.kChangeStateSpeech2Speech |
| return AudioChangeState.kChangeStateInvalid |
|
|
| def get_frame_size_ms(self) -> int: |
| return int(self.frame_size_ms) |
|
|
|
|
| class E2EVadModel: |
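| """End-to-end VAD: scores frames with the FSMN model and segments speech via a state machine.""" |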
| def __init__(self, config, vad_post_args: Dict[str, Any], root_dir: Path): |
| super(E2EVadModel, self).__init__() |
| self.vad_opts = VADXOptions(**vad_post_args) |
| self.windows_detector = WindowDetector( |
| self.vad_opts.window_size_ms, |
| self.vad_opts.sil_to_speech_time_thres, |
| self.vad_opts.speech_to_sil_time_thres, |
| self.vad_opts.frame_in_ms, |
| ) |
| self.model = VadOrtInferRuntimeSession(config, root_dir) |
| self.all_reset_detection() |
|
|
| def all_reset_detection(self): |
| |
| self.is_final = False |
| self.data_buf_start_frame = 0 |
| self.frm_cnt = 0 |
| self.latest_confirmed_speech_frame = 0 |
| self.lastest_confirmed_silence_frame = -1 |
| self.continous_silence_frame_count = 0 |
| self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected |
| self.confirmed_start_frame = -1 |
| self.confirmed_end_frame = -1 |
| self.number_end_time_detected = 0 |
| self.sil_frame = 0 |
| self.sil_pdf_ids = self.vad_opts.sil_pdf_ids |
| self.noise_average_decibel = -100.0 |
| self.pre_end_silence_detected = False |
| self.next_seg = True |
|
|
| self.output_data_buf = [] |
| self.output_data_buf_offset = 0 |
| self.frame_probs = [] |
| self.max_end_sil_frame_cnt_thresh = ( |
| self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres |
| ) |
| self.speech_noise_thres = self.vad_opts.speech_noise_thres |
| self.scores = None |
| self.scores_offset = 0 |
| self.max_time_out = False |
| self.decibel = [] |
| self.decibel_offset = 0 |
| self.data_buf_size = 0 |
| self.data_buf_all_size = 0 |
| self.waveform = None |
| self.reset_detection() |
|
|
| def reset_detection(self): |
| self.continous_silence_frame_count = 0 |
| self.latest_confirmed_speech_frame = 0 |
| self.lastest_confirmed_silence_frame = -1 |
| self.confirmed_start_frame = -1 |
| self.confirmed_end_frame = -1 |
| self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected |
| self.windows_detector.reset() |
| self.sil_frame = 0 |
| self.frame_probs = [] |
|
|
| def compute_decibel(self) -> None: |
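| """Append per-frame energies (dB) of the buffered waveform to self.decibel.""" |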
| frame_sample_length = int( |
| self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 |
| ) |
| frame_shift_length = int( |
| self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000 |
| ) |
| if self.data_buf_all_size == 0: |
| self.data_buf_all_size = len(self.waveform[0]) |
| self.data_buf_size = self.data_buf_all_size |
| else: |
| self.data_buf_all_size += len(self.waveform[0]) |
|
|
| for offset in range( |
| 0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length |
| ): |
| self.decibel.append( |
| 10 |
| * np.log10( |
| np.square( |
| self.waveform[0][offset : offset + frame_sample_length] |
| ).sum() |
| + 1e-6 |
| ) |
| ) |
|
|
| def compute_scores(self, feats) -> list: |
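| # Runs the VAD model on feats, stores the frame scores, and returns the updated model caches. |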
| scores = self.model(feats) |
| self.vad_opts.nn_eval_block_size = scores[0].shape[1] |
| self.frm_cnt += scores[0].shape[1] |
| if isinstance(feats, list): |
| |
| feats = feats[0] |
|
|
| assert ( |
| scores[0].shape[1] == feats.shape[1] |
| ), "The shape between feats and scores does not match" |
|
|
| self.scores = scores[0] |
| self.scores_offset += self.scores.shape[1] |
|
|
| return scores[1:] |
|
|
| def pop_data_buf_till_frame(self, frame_idx) -> None: |
| while self.data_buf_start_frame < frame_idx: |
| if self.data_buf_size >= int( |
| self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000 |
| ): |
| self.data_buf_start_frame += 1 |
| self.data_buf_size = ( |
| self.data_buf_all_size |
| - self.data_buf_start_frame |
| * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) |
| ) |
|
|
| def pop_data_to_output_buf( |
| self, |
| start_frm, |
| frm_cnt, |
| first_frm_is_start_point: bool, |
| last_frm_is_end_point: bool, |
| end_point_is_sent_end: bool, |
| ) -> None: |
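| """Account for frm_cnt frames starting at start_frm and extend the current output segment.""" |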
| self.pop_data_buf_till_frame(start_frm) |
| expected_sample_number = int( |
| frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000 |
| ) |
| if last_frm_is_end_point: |
| extra_sample = max( |
| 0, |
| int( |
| self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 |
| - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000 |
| ), |
| ) |
| expected_sample_number += int(extra_sample) |
| if end_point_is_sent_end: |
| expected_sample_number = max(expected_sample_number, self.data_buf_size) |
| if self.data_buf_size < expected_sample_number: |
| logging.error("error in calling pop data_buf\n") |
|
|
| if len(self.output_data_buf) == 0 or first_frm_is_start_point: |
| self.output_data_buf.append(E2EVadSpeechBufWithDoa()) |
| self.output_data_buf[-1].reset() |
| self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms |
| self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms |
| self.output_data_buf[-1].doa = 0 |
| cur_seg = self.output_data_buf[-1] |
| if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: |
| logging.error("warning\n") |
| out_pos = len(cur_seg.buffer) |
| data_to_pop = 0 |
| if end_point_is_sent_end: |
| data_to_pop = expected_sample_number |
| else: |
| data_to_pop = int( |
| frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000 |
| ) |
| if data_to_pop > self.data_buf_size: |
| logging.error("VAD data_to_pop is bigger than self.data_buf.size()!!!\n") |
| data_to_pop = self.data_buf_size |
| expected_sample_number = self.data_buf_size |
|
|
| cur_seg.doa = 0 |
| for sample_cpy_out in range(0, data_to_pop): |
| |
| out_pos += 1 |
| for sample_cpy_out in range(data_to_pop, expected_sample_number): |
| |
| out_pos += 1 |
| if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: |
| logging.error("Something wrong with the VAD algorithm\n") |
| self.data_buf_start_frame += frm_cnt |
| cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms |
| if first_frm_is_start_point: |
| cur_seg.contain_seg_start_point = True |
| if last_frm_is_end_point: |
| cur_seg.contain_seg_end_point = True |
|
|
| def on_silence_detected(self, valid_frame): |
| self.lastest_confirmed_silence_frame = valid_frame |
| if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: |
| self.pop_data_buf_till_frame(valid_frame) |
| |
| |
|
|
| def on_voice_detected(self, valid_frame) -> None: |
| self.latest_confirmed_speech_frame = valid_frame |
| self.pop_data_to_output_buf(valid_frame, 1, False, False, False) |
|
|
| def on_voice_start(self, start_frame, fake_result: bool = False) -> None: |
| if self.vad_opts.do_start_point_detection: |
| pass |
| if self.confirmed_start_frame != -1: |
| logging.error("not reset vad properly\n") |
| else: |
| self.confirmed_start_frame = start_frame |
|
|
| if ( |
| not fake_result |
| and self.vad_state_machine |
| == VadStateMachine.kVadInStateStartPointNotDetected |
| ): |
| self.pop_data_to_output_buf( |
| self.confirmed_start_frame, 1, True, False, False |
| ) |
|
|
| def on_voice_end( |
| self, end_frame, fake_result: bool, is_last_frame: bool |
| ) -> None: |
| for t in range(self.latest_confirmed_speech_frame + 1, end_frame): |
| self.on_voice_detected(t) |
| if self.vad_opts.do_end_point_detection: |
| pass |
| if self.confirmed_end_frame != -1: |
| logging.error("not reset vad properly\n") |
| else: |
| self.confirmed_end_frame = end_frame |
| if not fake_result: |
| self.sil_frame = 0 |
| self.pop_data_to_output_buf( |
| self.confirmed_end_frame, 1, False, True, is_last_frame |
| ) |
| self.number_end_time_detected += 1 |
|
|
| def maybe_on_voice_end_last_frame( |
| self, is_final_frame: bool, cur_frm_idx |
| ) -> None: |
| if is_final_frame: |
| self.on_voice_end(cur_frm_idx, False, True) |
| self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
|
|
| def get_latency(self) -> int: |
| return int(self.latency_frm_num_at_start_point() * self.vad_opts.frame_in_ms) |
|
|
| def latency_frm_num_at_start_point(self) -> int: |
| vad_latency = self.windows_detector.get_win_size() |
| if self.vad_opts.do_extend: |
| vad_latency += int( |
| self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms |
| ) |
| return vad_latency |
|
|
| def get_frame_state(self, t) -> FrameState: |
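| """Classify frame t as speech or silence from its dB level and the model's silence-pdf scores.""" |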
| frame_state = FrameState.kFrameStateInvalid |
| cur_decibel = self.decibel[t - self.decibel_offset] |
| cur_snr = cur_decibel - self.noise_average_decibel |
| |
| if cur_decibel < self.vad_opts.decibel_thres: |
| frame_state = FrameState.kFrameStateSil |
| self.detect_one_frame(frame_state, t, False) |
| return frame_state |
|
|
| sum_score = 0.0 |
| noise_prob = 0.0 |
| assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num |
| if len(self.sil_pdf_ids) > 0: |
| assert len(self.scores) == 1 |
| sil_pdf_scores = [ |
| self.scores[0][t - self.scores_offset][sil_pdf_id] |
| for sil_pdf_id in self.sil_pdf_ids |
| ] |
| sum_score = sum(sil_pdf_scores) |
| noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio |
| total_score = 1.0 |
| sum_score = total_score - sum_score |
| speech_prob = math.log(sum_score) |
| if self.vad_opts.output_frame_probs: |
| frame_prob = E2EVadFrameProb() |
| frame_prob.noise_prob = noise_prob |
| frame_prob.speech_prob = speech_prob |
| frame_prob.score = sum_score |
| frame_prob.frame_id = t |
| self.frame_probs.append(frame_prob) |
| if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres: |
| if ( |
| cur_snr >= self.vad_opts.snr_thres |
| and cur_decibel >= self.vad_opts.decibel_thres |
| ): |
| frame_state = FrameState.kFrameStateSpeech |
| else: |
| frame_state = FrameState.kFrameStateSil |
| else: |
| frame_state = FrameState.kFrameStateSil |
| if self.noise_average_decibel < -99.9: |
| self.noise_average_decibel = cur_decibel |
| else: |
| self.noise_average_decibel = ( |
| cur_decibel |
| + self.noise_average_decibel |
| * (self.vad_opts.noise_frame_num_used_for_snr - 1) |
| ) / self.vad_opts.noise_frame_num_used_for_snr |
|
|
| return frame_state |
|
|
| def infer_offline( |
| self, |
| feats: np.ndarray, |
| waveform: np.ndarray, |
| in_cache: Dict[str, np.ndarray] = dict(), |
| is_final: bool = False, |
| ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]: |
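| """Offline VAD over the full utterance; returns per-batch [start_ms, end_ms] segments and the cache.""" |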
| self.waveform = waveform |
| self.compute_decibel() |
|
|
| self.compute_scores(feats) |
| if not is_final: |
| self.detect_common_frames() |
| else: |
| self.detect_last_frames() |
| segments = [] |
| for batch_num in range(0, feats.shape[0]): |
| segment_batch = [] |
| if len(self.output_data_buf) > 0: |
| for i in range(self.output_data_buf_offset, len(self.output_data_buf)): |
| if ( |
| not self.output_data_buf[i].contain_seg_start_point |
| or not self.output_data_buf[i].contain_seg_end_point |
| ): |
| continue |
| segment = [ |
| self.output_data_buf[i].start_ms, |
| self.output_data_buf[i].end_ms, |
| ] |
| segment_batch.append(segment) |
| self.output_data_buf_offset += 1 |
| if segment_batch: |
| segments.append(segment_batch) |
|
|
| if is_final: |
| |
| self.all_reset_detection() |
| return segments, in_cache |
|
|
| def infer_online( |
| self, |
| feats: np.ndarray, |
| waveform: np.ndarray, |
| in_cache: list = None, |
| is_final: bool = False, |
| max_end_sil = 800, |
| ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]: |
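| """Streaming VAD step; -1 in a returned [start_ms, end_ms] pair marks a boundary not yet confirmed.""" |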
| feats = [feats] |
| if in_cache is None: |
| in_cache = [] |
|
|
| self.max_end_sil_frame_cnt_thresh = ( |
| max_end_sil - self.vad_opts.speech_to_sil_time_thres |
| ) |
| self.waveform = waveform |
| feats.extend(in_cache) |
| in_cache = self.compute_scores(feats) |
| self.compute_decibel() |
|
|
| if is_final: |
| self.detect_last_frames() |
| else: |
| self.detect_common_frames() |
|
|
| segments = [] |
| |
| for batch_num in range(0, feats[0].shape[0]): |
| if len(self.output_data_buf) > 0: |
| for i in range(self.output_data_buf_offset, len(self.output_data_buf)): |
| if not self.output_data_buf[i].contain_seg_start_point: |
| continue |
| if ( |
| not self.next_seg |
| and not self.output_data_buf[i].contain_seg_end_point |
| ): |
| continue |
| start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1 |
| if self.output_data_buf[i].contain_seg_end_point: |
| end_ms = self.output_data_buf[i].end_ms |
| self.next_seg = True |
| self.output_data_buf_offset += 1 |
| else: |
| end_ms = -1 |
| self.next_seg = False |
| segments.append([start_ms, end_ms]) |
|
|
| return segments, in_cache |
|
|
| def get_frames_state( |
| self, |
| feats: np.ndarray, |
| waveform: np.ndarray, |
| in_cache: list = None, |
| is_final: bool = False, |
| max_end_sil = 800, |
| ): |
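| """Like a streaming step, but returns the per-frame FrameState list instead of segments.""" |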
| feats = [feats] |
| states = [] |
| if in_cache is None: |
| in_cache = [] |
|
|
| self.max_end_sil_frame_cnt_thresh = ( |
| max_end_sil - self.vad_opts.speech_to_sil_time_thres |
| ) |
| self.waveform = waveform |
| feats.extend(in_cache) |
| in_cache = self.compute_scores(feats) |
| self.compute_decibel() |
|
|
| if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| return states |
|
|
| for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| frame_state = FrameState.kFrameStateInvalid |
| frame_state = self.get_frame_state(self.frm_cnt - 1 - i) |
| states.append(frame_state) |
| if i == 0 and is_final: |
| logging.info("last frame detected") |
| self.detect_one_frame(frame_state, self.frm_cnt - 1, True) |
| else: |
| self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False) |
|
|
| return states |
|
|
| def detect_common_frames(self) -> int: |
| if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| return 0 |
| for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| frame_state = FrameState.kFrameStateInvalid |
| frame_state = self.get_frame_state(self.frm_cnt - 1 - i) |
| |
| self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False) |
|
|
| self.decibel = self.decibel[self.vad_opts.nn_eval_block_size - 1 :] |
| self.decibel_offset = self.frm_cnt - 1 |
| return 0 |
|
|
| def detect_last_frames(self) -> int: |
| if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |
| return 0 |
| for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): |
| frame_state = FrameState.kFrameStateInvalid |
| frame_state = self.get_frame_state(self.frm_cnt - 1 - i) |
| if i != 0: |
| self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False) |
| else: |
| self.detect_one_frame(frame_state, self.frm_cnt - 1, True) |
|
|
| return 0 |
|
|
| def detect_one_frame( |
| self, cur_frm_state: FrameState, cur_frm_idx, is_final_frame: bool |
| ) -> None: |
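| """Advance the VAD state machine by one frame, emitting segment starts/ends as they are confirmed.""" |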
| tmp_cur_frm_state = FrameState.kFrameStateInvalid |
| if cur_frm_state == FrameState.kFrameStateSpeech: |
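| # math.fabs(1.0) makes the next check always true for the default fe_prior_thres; kept from the original logic. |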
| if math.fabs(1.0) > float(self.vad_opts.fe_prior_thres): |
| tmp_cur_frm_state = FrameState.kFrameStateSpeech |
| else: |
| tmp_cur_frm_state = FrameState.kFrameStateSil |
| elif cur_frm_state == FrameState.kFrameStateSil: |
| tmp_cur_frm_state = FrameState.kFrameStateSil |
| state_change = self.windows_detector.detect_one_frame( |
| tmp_cur_frm_state, cur_frm_idx |
| ) |
| frm_shift_in_ms = self.vad_opts.frame_in_ms |
| if AudioChangeState.kChangeStateSil2Speech == state_change: |
| self.continous_silence_frame_count = 0 |
| self.pre_end_silence_detected = False |
|
|
| if ( |
| self.vad_state_machine |
| == VadStateMachine.kVadInStateStartPointNotDetected |
| ): |
| start_frame = max( |
| self.data_buf_start_frame, |
| cur_frm_idx - self.latency_frm_num_at_start_point(), |
| ) |
| self.on_voice_start(start_frame) |
| self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment |
| for t in range(start_frame + 1, cur_frm_idx + 1): |
| self.on_voice_detected(t) |
| elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx): |
| self.on_voice_detected(t) |
| if ( |
| cur_frm_idx - self.confirmed_start_frame + 1 |
| > self.vad_opts.max_single_segment_time / frm_shift_in_ms |
| ): |
| self.on_voice_end(cur_frm_idx, False, False) |
| self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| elif not is_final_frame: |
| self.on_voice_detected(cur_frm_idx) |
| else: |
| self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx) |
| else: |
| pass |
| elif AudioChangeState.kChangeStateSpeech2Sil == state_change: |
| self.continous_silence_frame_count = 0 |
| if ( |
| self.vad_state_machine |
| == VadStateMachine.kVadInStateStartPointNotDetected |
| ): |
| pass |
| elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| if ( |
| cur_frm_idx - self.confirmed_start_frame + 1 |
| > self.vad_opts.max_single_segment_time / frm_shift_in_ms |
| ): |
| self.on_voice_end(cur_frm_idx, False, False) |
| self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| elif not is_final_frame: |
| self.on_voice_detected(cur_frm_idx) |
| else: |
| self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx) |
| else: |
| pass |
| elif AudioChangeState.kChangeStateSpeech2Speech == state_change: |
| self.continous_silence_frame_count = 0 |
| if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| if ( |
| cur_frm_idx - self.confirmed_start_frame + 1 |
| > self.vad_opts.max_single_segment_time / frm_shift_in_ms |
| ): |
| self.max_time_out = True |
| self.on_voice_end(cur_frm_idx, False, False) |
| self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| elif not is_final_frame: |
| self.on_voice_detected(cur_frm_idx) |
| else: |
| self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx) |
| else: |
| pass |
| elif AudioChangeState.kChangeStateSil2Sil == state_change: |
| self.continous_silence_frame_count += 1 |
| if ( |
| self.vad_state_machine |
| == VadStateMachine.kVadInStateStartPointNotDetected |
| ): |
| |
| if ( |
| ( |
| self.vad_opts.detect_mode |
| == VadDetectMode.kVadSingleUtteranceDetectMode.value |
| ) |
| and ( |
| self.continous_silence_frame_count * frm_shift_in_ms |
| > self.vad_opts.max_start_silence_time |
| ) |
| ) or (is_final_frame and self.number_end_time_detected == 0): |
| for t in range( |
| self.lastest_confirmed_silence_frame + 1, cur_frm_idx |
| ): |
| self.on_silence_detected(t) |
| self.on_voice_start(0, True) |
| self.on_voice_end(0, True, False) |
| self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| else: |
| if cur_frm_idx >= self.latency_frm_num_at_start_point(): |
| self.on_silence_detected( |
| cur_frm_idx - self.latency_frm_num_at_start_point() |
| ) |
| elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: |
| if ( |
| self.continous_silence_frame_count * frm_shift_in_ms |
| >= self.max_end_sil_frame_cnt_thresh |
| ): |
| lookback_frame = int( |
| self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms |
| ) |
| if self.vad_opts.do_extend: |
| lookback_frame -= int( |
| self.vad_opts.lookahead_time_end_point / frm_shift_in_ms |
| ) |
| lookback_frame -= 1 |
| lookback_frame = max(0, lookback_frame) |
| self.on_voice_end(cur_frm_idx - lookback_frame, False, False) |
| self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| elif ( |
| cur_frm_idx - self.confirmed_start_frame + 1 |
| > self.vad_opts.max_single_segment_time / frm_shift_in_ms |
| ): |
| self.on_voice_end(cur_frm_idx, False, False) |
| self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected |
| elif self.vad_opts.do_extend and not is_final_frame: |
| if self.continous_silence_frame_count <= int( |
| self.vad_opts.lookahead_time_end_point / frm_shift_in_ms |
| ): |
| self.on_voice_detected(cur_frm_idx) |
| else: |
| self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx) |
| else: |
| pass |
|
|
| if ( |
| self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected |
| and self.vad_opts.detect_mode |
| == VadDetectMode.kVadMultipleUtteranceDetectMode.value |
| ): |
| self.reset_detection() |
|
|
|
|
| class FSMNVad(object): |
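| """Convenience wrapper: loads fsmn-config.yaml, the CMVN file, and fsmnvad-offline.onnx from config_dir.""" |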
| def __init__(self, config_dir): |
| config_dir = Path(config_dir) |
| self.config = read_yaml(config_dir / "fsmn-config.yaml") |
| self.frontend = WavFrontend( |
| cmvn_file=config_dir / "fsmn-am.mvn", |
| **self.config["WavFrontend"]["frontend_conf"], |
| ) |
| self.config["FSMN"]["model_path"] = config_dir / "fsmnvad-offline.onnx" |
|
|
| self.vad = E2EVadModel( |
| self.config["FSMN"], self.config["vadPostArgs"], config_dir |
| ) |
|
|
| def set_parameters(self, mode): |
| pass |
|
|
| def extract_feature(self, waveform): |
| fbank, _ = self.frontend.fbank(waveform) |
| feats, feats_len = self.frontend.lfr_cmvn(fbank) |
| return feats.astype(np.float32), feats_len |
|
|
| def is_speech(self, buf, sample_rate=16000): |
| assert sample_rate == 16000, "only support 16k sample rate" |
|
|
| def segments_offline(self, waveform_path: Union[str, Path, np.ndarray]): |
| """Get speech segments ([start_ms, end_ms]) from an audio file or waveform array.""" |
|
|
| if isinstance(waveform_path, np.ndarray): |
| waveform = waveform_path |
| else: |
| if not os.path.exists(waveform_path): |
| raise FileNotFoundError(f"{waveform_path} does not exist.") |
| if os.path.isfile(waveform_path): |
| logging.info(f"load audio {waveform_path}") |
| waveform, _sample_rate = sf.read( |
| waveform_path, |
| dtype="float32", |
| ) |
| assert ( |
| _sample_rate == 16000 |
| ), f"only support 16k sample rate, current sample rate is {_sample_rate}" |
| else: |
| raise FileNotFoundError(f"{waveform_path} is not a file.") |
|
|
| feats, feats_len = self.extract_feature(waveform) |
| waveform = waveform[None, ...] |
| segments_part, in_cache = self.vad.infer_offline( |
| feats[None, ...], waveform, is_final=True |
| ) |
| return segments_part[0] |
|
|
|
|
|
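| # Language ids matching rows of the SenseVoice language-query embedding table. |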
| languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13} |
|
|
| def parse_args(): |
| arg_parser = argparse.ArgumentParser(description="Sense Voice") |
| arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model") |
| download_model_path = os.path.dirname(__file__) |
| arg_parser.add_argument( |
| "-dp", |
| "--download_path", |
| default=download_model_path, |
| type=str, |
| help="dir path of resource downloaded" |
| ) |
| arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device") |
| arg_parser.add_argument( |
| "-n", "--num_threads", default=4, type=int, help="Num threads" |
| ) |
| arg_parser.add_argument( |
| "-l", |
| "--language", |
| choices=languages.keys(), |
| default="auto", |
| type=str, |
| help="Language" |
| ) |
| arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN") |
| return arg_parser.parse_args() |
|
|
| def main(audio_file, download_path, device, num_threads, language, use_itn): |
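| """Segment each audio channel with FSMN-VAD and transcribe every segment with SenseVoice.""" |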
| front = WavFrontend(os.path.join(download_path, "am.mvn")) |
|
|
| model = SenseVoiceInferenceSession( |
| os.path.join(download_path, "embedding.npy"), |
| os.path.join( |
| download_path, |
| "sense-voice-encoder.rknn", |
| ), |
| os.path.join(download_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"), |
| device, |
| num_threads, |
| ) |
| waveform, _sample_rate = sf.read( |
| audio_file, |
| dtype="float32", |
| always_2d=True |
| ) |
|
|
| logging.info(f"Audio {audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel") |
| |
| start = time.time() |
| vad = FSMNVad(download_path) |
| for channel_id, channel_data in enumerate(waveform.T): |
| segments = vad.segments_offline(channel_data) |
| results = "" |
| for part in segments: |
| audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16]) |
| asr_result = model( |
| audio_feats[None, ...], |
| language=languages[language], |
| use_itn=use_itn, |
| ) |
| logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}") |
| results += asr_result |
| logging.info(f"Results: {results}") |
| vad.vad.all_reset_detection() |
| decoding_time = time.time() - start |
| logging.info(f"Decoding audio took {decoding_time:.2f} seconds") |
| logging.info(f"The RTF is {decoding_time / (waveform.shape[1] * len(waveform) / _sample_rate):.3f}.") |
| return results |
|
|
|
|
| if __name__ == "__main__": |
| args = parse_args() |
| main(args.audio_file, args.download_path, args.device, args.num_threads, args.language, args.use_itn) |
|
|
|
|