File size: 6,815 Bytes
d21d362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#include <stdexcept>
#include <cmath>
#include <iostream>

#include "onnx_wrapper.h"

// Collect the model's input tensor names.
// The std::string vector owns the storage; the char* vector holds pointers
// into it, in the layout the ONNX Runtime Run() API expects.
static void get_input_names(Ort::Session* session, std::vector<std::string> &input_names_str,
                   std::vector<const char *> &input_names_char) {
    Ort::AllocatorWithDefaultOptions allocator;
    const size_t count = session->GetInputCount();
    input_names_str.clear();
    input_names_char.clear();
    // Reserve up front so c_str() pointers stay stable while we append.
    input_names_str.reserve(count);
    input_names_char.reserve(count);
    for (size_t i = 0; i < count; ++i) {
        auto name = session->GetInputNameAllocated(i, allocator);
        input_names_str.emplace_back(name.get());
        input_names_char.push_back(input_names_str.back().c_str());
    }
}

// Collect the model's output tensor names.
// Mirrors get_input_names: the string vector owns the names, the char*
// vector aliases them for the ONNX Runtime Run() API.
static void get_output_names(Ort::Session* session, std::vector<std::string> &output_names_,
                   std::vector<const char *> &vad_out_names_) {
    Ort::AllocatorWithDefaultOptions allocator;
    const size_t count = session->GetOutputCount();
    output_names_.clear();
    vad_out_names_.clear();
    // Reserve up front so c_str() pointers stay stable while we append.
    output_names_.reserve(count);
    vad_out_names_.reserve(count);
    for (size_t i = 0; i < count; ++i) {
        auto name = session->GetOutputNameAllocated(i, allocator);
        output_names_.emplace_back(name.get());
        vad_out_names_.push_back(output_names_.back().c_str());
    }
}

// Build an ONNX Runtime session for the VAD model, cache its input/output
// tensor names, and initialize the streaming state.
// A load failure terminates the process (matching the original behavior).
OnnxVadWrapper::OnnxVadWrapper(const std::string& model_path, bool force_cpu, int thread_num)
    : sample_rates_{16000}, model_path_(model_path) {
    Ort::SessionOptions opts;
    opts.SetIntraOpNumThreads(thread_num);
    opts.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    opts.DisableCpuMemArena();

    // if (force_cpu && supports_cpu()) {
    //     opts.AppendExecutionProvider_CPU();
    // }

    // Create the ONNX session.
    try {
        env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "OnnxVadWrapper");
        session_ = std::make_unique<Ort::Session>(env_, ORTCHAR(model_path.c_str()), opts);
        std::cout << "Successfully load model from " << model_path << std::endl;
    } catch (std::exception const &e) {
        std::cout << "Error when load vad onnx model: " << e.what() << std::endl;
        exit(-1);
    }

    // Query the model's tensor names once; Run() reuses them on every call.
    get_input_names(session_.get(), input_names_, vad_in_names_);
    get_output_names(session_.get(), output_names_, vad_out_names_);

    reset_states();
}

// Defaulted destructor, defined out-of-line in the .cpp.
// NOTE(review): presumably kept here so the std::unique_ptr<Ort::Session>
// member is destroyed where Ort::Session is fully defined — confirm header.
OnnxVadWrapper::~OnnxVadWrapper() = default;

// Reset all streaming state: zero the recurrent model state
// (2 x batch_size x 128 floats), drop the saved audio context, and forget
// the last seen sample rate and batch size so the next call reinitializes.
void OnnxVadWrapper::reset_states(int batch_size) {
    // assign() sizes and zero-fills in one step; the previous
    // resize()+assign(state_.size(), ...) pair did the same work twice.
    state_.assign(static_cast<size_t>(2) * batch_size * 128, 0.0f);
    context_.clear();
    last_sr_ = 0;
    last_batch_size_ = 0;
}

// Run the VAD model on one fixed-size audio chunk and advance the streaming
// state (recurrent state tensor + trailing audio context) for the next call.
//
// x  : exactly 512 samples when sr == 16000, otherwise 256 samples.
// sr : sampling rate, checked by validate_input().
// Returns {model output values, empty vector}.
// Throws std::invalid_argument on a wrong-sized chunk; terminates the
// process if the ONNX Run() call itself fails.
std::pair<std::vector<float>, std::vector<float>> OnnxVadWrapper::operator()(const std::vector<float>& x, int sr) {
    validate_input(x, sr);

    int num_samples = (sr == 16000) ? 512 : 256;
    int context_size = (sr == 16000) ? 64 : 32;

    int batch_size = 1;  // single-channel input is assumed
    if (x.size() != num_samples) {
        throw std::invalid_argument("Input must be exactly " + std::to_string(num_samples) + " samples.");
    }

    // (Re)initialize state on the first call, or whenever the sample rate
    // or batch size changed between calls.
    if (!last_batch_size_) reset_states(batch_size);
    if (last_sr_ != 0 && last_sr_ != sr) reset_states(batch_size);
    if (last_batch_size_ != 0 && last_batch_size_ != batch_size) reset_states(batch_size);

    if (context_.empty()) {
        context_.resize(batch_size * context_size, 0.0f);
    }

    // Prepend the saved context to the new chunk: the model is fed
    // context_size + num_samples samples per step.
    std::vector<float> x_with_context(context_.begin(), context_.end());
    x_with_context.insert(x_with_context.end(), x.begin(), x.end());

    // Prepare inputs. The Ort::Value tensors are zero-copy views over local
    // buffers (x_with_context, state_, sr_f); all stay alive until Run().
    std::vector<Ort::Value> inputs;
    auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    std::array<int64_t, 3> input_shape = {1, 1, static_cast<int64_t>(x_with_context.size())};
    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
        mem_info, const_cast<float*>(x_with_context.data()), x_with_context.size(),
        input_shape.data(), input_shape.size());
    inputs.emplace_back(std::move(input_tensor));

    // Recurrent state tensor: shape {2, batch, 128}, backed by state_.
    std::array<int64_t, 3> state_shape = {2, batch_size, 128};
    Ort::Value state_tensor = Ort::Value::CreateTensor<float>(
        mem_info, state_.data(), state_.size(), state_shape.data(), state_shape.size());
    inputs.emplace_back(std::move(state_tensor));

    // NOTE(review): the "sr" input is fed as a float tensor — confirm the
    // exported model's expected dtype for this input (some exports use int64).
    std::array<int64_t, 1> sr_shape = {1};
    float sr_f = static_cast<float>(sr);
    Ort::Value sr_tensor = Ort::Value::CreateTensor<float>(
        mem_info, &sr_f, 1, sr_shape.data(), sr_shape.size());
    inputs.emplace_back(std::move(sr_tensor));

    // const char* input_names[] = {"input", "state", "sr"};
    // std::vector<Ort::Value> inputs = {std::move(input_tensor), std::move(state_tensor), std::move(sr_tensor)};

    // Run inference; any runtime failure terminates the process.
    std::vector<Ort::Value> outputs;
    try {
        outputs = session_->Run(
                Ort::RunOptions{nullptr}, vad_in_names_.data(), inputs.data(),
                inputs.size(), vad_out_names_.data(), vad_out_names_.size());
    } catch (std::exception const &e) {
        std::cout << "Error when run vad onnx forword: " << e.what() << std::endl;
        exit(-1);
    }

    // Copy the first output (model result) into an owned vector.
    float* out_data = outputs[0].GetTensorMutableData<float>();
    size_t out_len = outputs[0].GetTensorTypeAndShapeInfo().GetElementCount();
    std::vector<float> out(out_data, out_data + out_len);

    // Persist the updated recurrent state for the next call.
    float* new_state = outputs[1].GetTensorMutableData<float>();
    std::copy(new_state, new_state + state_.size(), state_.begin());

    // Keep the last context_size samples as context for the next chunk.
    context_.assign(x_with_context.end() - context_size, x_with_context.end());

    last_sr_ = sr;
    last_batch_size_ = batch_size;

    return {out, {}};
}

// Run the VAD over a whole utterance: reset state, zero-pad the audio to a
// multiple of the chunk length, then stream it chunk by chunk through
// operator(), concatenating the per-chunk outputs.
std::vector<float> OnnxVadWrapper::audio_forward(const std::vector<float>& audio, int sr) {
    reset_states();

    const size_t chunk_len = (sr == 16000) ? 512 : 256;

    // Zero-pad so the length is an exact multiple of chunk_len.
    std::vector<float> padded = audio;
    const size_t remainder = padded.size() % chunk_len;
    if (remainder != 0) {
        padded.resize(padded.size() + (chunk_len - remainder), 0.0f);
    }

    std::vector<float> result;
    for (size_t offset = 0; offset < padded.size(); offset += chunk_len) {
        std::vector<float> chunk(padded.begin() + offset, padded.begin() + offset + chunk_len);
        auto [chunk_out, unused] = (*this)(chunk, sr);
        (void)unused;
        result.insert(result.end(), chunk_out.begin(), chunk_out.end());
    }

    return result;
}

// Report whether the linked ONNX Runtime lists the CPU execution provider.
bool OnnxVadWrapper::supports_cpu() {
    for (const std::string& name : Ort::GetAvailableProviders()) {
        if (name == "CPUExecutionProvider") {
            return true;
        }
    }
    return false;
}

// Validate an input chunk before inference.
//
// x  : audio samples for one chunk.
// sr : sampling rate; 16000 or a positive multiple of 16000 is accepted.
// Throws std::invalid_argument for an unsupported rate, an empty chunk, or
// a chunk shorter than sr / 31.25 samples (i.e. < 32 ms of audio).
void OnnxVadWrapper::validate_input(const std::vector<float>& x, int sr) {
    if (sr != 16000 && sr % 16000 != 0) {
        throw std::invalid_argument("Unsupported sampling rate: " + std::to_string(sr));
    }

    // Guard first: the ratio check below would divide by zero on empty input.
    if (x.empty()) {
        throw std::invalid_argument("Input audio chunk is too short");
    }

    // Use floating-point division. The previous integer division truncated,
    // so e.g. 16000 / 511 == 31 slipped past the 31.25 threshold even though
    // 16000 / 511.0 ≈ 31.31 should be rejected.
    if (static_cast<double>(sr) / static_cast<double>(x.size()) > 31.25) {
        throw std::invalid_argument("Input audio chunk is too short");
    }
}