| #include <stdexcept> |
| #include <cmath> |
| #include <iostream> |
|
|
| #include "onnx_wrapper.h" |
|
|
| static void get_input_names(Ort::Session* session, std::vector<std::string> &input_names_str, |
| std::vector<const char *> &input_names_char) { |
| Ort::AllocatorWithDefaultOptions allocator; |
| size_t nodes_num = session->GetInputCount(); |
| input_names_str.resize(nodes_num); |
| input_names_char.resize(nodes_num); |
| for (size_t i = 0; i != nodes_num; ++i) { |
| auto t = session->GetInputNameAllocated(i, allocator); |
| input_names_str[i] = t.get(); |
| input_names_char[i] = input_names_str[i].c_str(); |
| } |
| } |
|
|
| static void get_output_names(Ort::Session* session, std::vector<std::string> &output_names_, |
| std::vector<const char *> &vad_out_names_) { |
| Ort::AllocatorWithDefaultOptions allocator; |
| size_t nodes_num = session->GetOutputCount(); |
| output_names_.resize(nodes_num); |
| vad_out_names_.resize(nodes_num); |
| for (size_t i = 0; i != nodes_num; ++i) { |
| auto t = session->GetOutputNameAllocated(i, allocator); |
| output_names_[i] = t.get(); |
| vad_out_names_[i] = output_names_[i].c_str(); |
| } |
| } |
|
|
| OnnxVadWrapper::OnnxVadWrapper(const std::string& model_path, bool force_cpu, int thread_num) |
| : sample_rates_{16000}, model_path_(model_path) { |
| Ort::SessionOptions session_options; |
| session_options.SetIntraOpNumThreads(thread_num); |
| session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL); |
| session_options.DisableCpuMemArena(); |
|
|
| |
| |
| |
|
|
| |
| try { |
| env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "OnnxVadWrapper"); |
| session_ = std::make_unique<Ort::Session>(env_, ORTCHAR(model_path.c_str()), session_options); |
| std::cout << "Successfully load model from " << model_path << std::endl; |
| } catch (std::exception const &e) { |
| std::cout << "Error when load vad onnx model: " << e.what() << std::endl; |
| exit(-1); |
| } |
|
|
| get_input_names(session_.get(), input_names_, vad_in_names_); |
| get_output_names(session_.get(), output_names_, vad_out_names_); |
|
|
| reset_states(); |
| } |
|
|
// Defaulted: the unique_ptr member releases the Ort::Session automatically.
OnnxVadWrapper::~OnnxVadWrapper() = default;
|
|
| void OnnxVadWrapper::reset_states(int batch_size) { |
| int total_size = 2 * batch_size * 128; |
| state_.resize(total_size); |
| state_.assign(state_.size(), 0.0f); |
| context_.clear(); |
| last_sr_ = 0; |
| last_batch_size_ = 0; |
| } |
|
|
| std::pair<std::vector<float>, std::vector<float>> OnnxVadWrapper::operator()(const std::vector<float>& x, int sr) { |
| validate_input(x, sr); |
|
|
| int num_samples = (sr == 16000) ? 512 : 256; |
| int context_size = (sr == 16000) ? 64 : 32; |
|
|
| int batch_size = 1; |
| if (x.size() != num_samples) { |
| throw std::invalid_argument("Input must be exactly " + std::to_string(num_samples) + " samples."); |
| } |
|
|
| if (!last_batch_size_) reset_states(batch_size); |
| if (last_sr_ != 0 && last_sr_ != sr) reset_states(batch_size); |
| if (last_batch_size_ != 0 && last_batch_size_ != batch_size) reset_states(batch_size); |
|
|
| if (context_.empty()) { |
| context_.resize(batch_size * context_size, 0.0f); |
| } |
|
|
| |
| std::vector<float> x_with_context(context_.begin(), context_.end()); |
| x_with_context.insert(x_with_context.end(), x.begin(), x.end()); |
|
|
| |
| std::vector<Ort::Value> inputs; |
| auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); |
| std::array<int64_t, 3> input_shape = {1, 1, static_cast<int64_t>(x_with_context.size())}; |
| Ort::Value input_tensor = Ort::Value::CreateTensor<float>( |
| mem_info, const_cast<float*>(x_with_context.data()), x_with_context.size(), |
| input_shape.data(), input_shape.size()); |
| inputs.emplace_back(std::move(input_tensor)); |
|
|
| std::array<int64_t, 3> state_shape = {2, batch_size, 128}; |
| Ort::Value state_tensor = Ort::Value::CreateTensor<float>( |
| mem_info, state_.data(), state_.size(), state_shape.data(), state_shape.size()); |
| inputs.emplace_back(std::move(state_tensor)); |
|
|
| std::array<int64_t, 1> sr_shape = {1}; |
| float sr_f = static_cast<float>(sr); |
| Ort::Value sr_tensor = Ort::Value::CreateTensor<float>( |
| mem_info, &sr_f, 1, sr_shape.data(), sr_shape.size()); |
| inputs.emplace_back(std::move(sr_tensor)); |
|
|
| |
| |
|
|
| |
| std::vector<Ort::Value> outputs; |
| try { |
| outputs = session_->Run( |
| Ort::RunOptions{nullptr}, vad_in_names_.data(), inputs.data(), |
| inputs.size(), vad_out_names_.data(), vad_out_names_.size()); |
| } catch (std::exception const &e) { |
| std::cout << "Error when run vad onnx forword: " << e.what() << std::endl; |
| exit(-1); |
| } |
|
|
| |
| float* out_data = outputs[0].GetTensorMutableData<float>(); |
| size_t out_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount(); |
| std::vector<float> out(out_data, out_data + out_len); |
|
|
| |
| float* new_state = outputs[1].GetTensorMutableData<float>(); |
| std::copy(new_state, new_state + state_.size(), state_.begin()); |
|
|
| context_.assign(x_with_context.end() - context_size, x_with_context.end()); |
|
|
| last_sr_ = sr; |
| last_batch_size_ = batch_size; |
|
|
| return {out, {}}; |
| } |
|
|
| std::vector<float> OnnxVadWrapper::audio_forward(const std::vector<float>& audio, int sr) { |
| std::vector<float> x = audio; |
| reset_states(); |
|
|
| int num_samples = (sr == 16000) ? 512 : 256; |
| std::vector<float> result; |
|
|
| |
| int pad_num = (num_samples - (x.size() % num_samples)) % num_samples; |
| x.resize(x.size() + pad_num, 0.0f); |
|
|
| for (size_t i = 0; i < x.size(); i += num_samples) { |
| std::vector<float> chunk(x.begin() + i, x.begin() + i + num_samples); |
| auto [out, _] = (*this)(chunk, sr); |
| result.insert(result.end(), out.begin(), out.end()); |
| } |
|
|
| return result; |
| } |
|
|
| bool OnnxVadWrapper::supports_cpu() { |
| auto providers = Ort::GetAvailableProviders(); |
|
|
| for (const std::string& provider : providers) { |
| if (provider == "CPUExecutionProvider") { |
| return true; |
| } |
| } |
|
|
| return false; |
| } |
|
|
| void OnnxVadWrapper::validate_input(const std::vector<float>& x, int sr) { |
| if (sr != 16000 && sr % 16000 != 0) { |
| throw std::invalid_argument("Unsupported sampling rate: " + std::to_string(sr)); |
| } |
|
|
| if ((sr / x.size()) > 31.25) { |
| throw std::invalid_argument("Input audio chunk is too short"); |
| } |
| } |
|
|