#include <stdio.h>
#include <sys/time.h>

#include <chrono>
#include <cstring>
#include <ctime>
#include <vector>

#include <ax_sys_api.h>

#include "Encoder.hpp"
#include "DecoderMain.hpp"
#include "DecoderLoop.hpp"

// Returns a monotonic timestamp in milliseconds.
//
// Only the difference between two calls is meaningful: the epoch is that of
// std::chrono::steady_clock, not the Unix epoch. A steady clock is used
// (instead of gettimeofday()) so wall-clock adjustments such as NTP steps or
// manual time changes cannot distort or negate the measured intervals.
static double get_current_time()
{
    using Milliseconds = std::chrono::duration<double, std::milli>;
    const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
    return Milliseconds(since_epoch).count();
}

| int main(int argc, char** argv) { |
| int ret = AX_SYS_Init(); |
| if (0 != ret) { |
| fprintf(stderr, "AX_SYS_Init failed! ret = 0x%x\n", ret); |
| return -1; |
| } |
|
|
| AX_ENGINE_NPU_ATTR_T npu_attr; |
| memset(&npu_attr, 0, sizeof(npu_attr)); |
| npu_attr.eHardMode = static_cast<AX_ENGINE_NPU_MODE_T>(0); |
| ret = AX_ENGINE_Init(&npu_attr); |
| if (0 != ret) { |
| fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret); |
| return -1; |
| } |
|
|
| Encoder encoder; |
| DecoderMain decoder_main; |
| DecoderLoop decoder_loop; |
|
|
| double start, end; |
| double whole_start, whole_end; |
|
|
| start = get_current_time(); |
| if (0 != encoder.Init("../axmodel/encoder.axmodel")) { |
| printf("Init encoder failed!\n"); |
| return -1; |
| } |
| end = get_current_time(); |
| printf("Load encoder take %.2fms\n", end - start); |
|
|
| start = get_current_time(); |
| if (0 != decoder_main.Init("../axmodel/decoder_main.axmodel")) { |
| printf("Init decoder_main failed!\n"); |
| return -1; |
| } |
| end = get_current_time(); |
| printf("Load decoder_main take %.2fms\n", end - start); |
|
|
| start = get_current_time(); |
| if (0 != decoder_loop.Init("../axmodel/decoder_loop.axmodel")) { |
| printf("Init decoder_loop failed!\n"); |
| return -1; |
| } |
| end = get_current_time(); |
| printf("Load decoder_loop take %.2fms\n", end - start); |
|
|
| std::vector<float> encoder_inputs(encoder.GetInputSize(0) / sizeof(float)); |
| std::vector<float> encoder_input_lengths(encoder.GetInputSize(1) / sizeof(float)); |
| encoder_input_lengths[0] = 100; |
|
|
| std::vector<float> n_layer_cross_k(encoder.GetOutputSize(0) / sizeof(float)); |
| std::vector<float> n_layer_cross_v(encoder.GetOutputSize(1) / sizeof(float)); |
| std::vector<float> cross_attn_mask(encoder.GetOutputSize(2) / sizeof(float)); |
|
|
| start = get_current_time(); |
| whole_start = start; |
| encoder.SetInput(encoder_inputs.data(), 0); |
| encoder.SetInput(encoder_input_lengths.data(), 1); |
| encoder.Run(); |
| |
| |
| |
| end = get_current_time(); |
| printf("Run encoder take %.2fms\n", end - start); |
|
|
| std::vector<int> tokens(decoder_main.GetInputSize(0) / sizeof(int)); |
|
|
| std::vector<int> logits(decoder_main.GetOutputSize(0) / sizeof(int)); |
| std::vector<float> n_layer_self_k_cache(decoder_main.GetOutputSize(1) / sizeof(float)); |
| std::vector<float> n_layer_self_v_cache(decoder_main.GetOutputSize(2) / sizeof(float)); |
|
|
| start = get_current_time(); |
| decoder_main.SetInput(tokens.data(), 0); |
| |
| |
| |
| decoder_main.SetInput(n_layer_cross_k.data(), 1); |
| decoder_main.SetInput(n_layer_cross_v.data(), 2); |
| decoder_main.SetInput(cross_attn_mask.data(), 3); |
| decoder_main.Run(); |
| decoder_main.GetOutput(logits.data(), 0); |
| |
| |
| end = get_current_time(); |
| printf("Run decoder_main take %.2fms\n", end - start); |
|
|
| std::vector<float> pe(decoder_loop.GetOutputSize(5) / sizeof(float)); |
| std::vector<float> self_attn_mask(decoder_loop.GetOutputSize(6) / sizeof(float)); |
|
|
| decoder_loop.SetInput(n_layer_cross_k.data(), 3); |
| decoder_loop.SetInput(n_layer_cross_v.data(), 4); |
| for (int i = 0; i < 14; i++) { |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| start = get_current_time(); |
| decoder_loop.SetInput(tokens.data(), 0); |
| decoder_loop.SetInput(decoder_loop.GetOutputPtr(1), 1); |
| decoder_loop.SetInput(decoder_loop.GetOutputPtr(2), 2); |
| |
| |
| |
| decoder_loop.Run(); |
| decoder_loop.GetOutput(logits.data(), 0); |
| |
| |
| end = get_current_time(); |
| printf("Run decoder_loop take %.2fms\n", end - start); |
| } |
|
|
| whole_end = get_current_time(); |
| printf("Whole duration %.2fms\n", whole_end - whole_start); |
| printf("RTF: %.4f\n", (whole_end - whole_start) / 4000.0); |
|
|
| return 0; |
| } |