| import os, psutil |
|
|
# Size OpenMP's thread pool to the number of logical CPUs and request the
# ACTIVE wait policy.
# NOTE(review): these are assigned before the onnxruntime import further down,
# presumably so the OpenMP runtime sees them when it initialises -- keep that
# ordering.
# psutil.cpu_count() returns None when the count cannot be determined; without
# a fallback that would yield the invalid value "None", so fall back to
# os.cpu_count() and finally to a single thread.
os.environ["OMP_NUM_THREADS"] = str(
    psutil.cpu_count(logical=True) or os.cpu_count() or 1
)
os.environ["OMP_WAIT_POLICY"] = "ACTIVE"
|
|
|
|
| from onnxruntime import ( |
| GraphOptimizationLevel, |
| InferenceSession, |
| SessionOptions, |
| ExecutionMode, |
| ) |
|
|
|
|
def _build_session_options(opt_level, parallel_exe_mode, n_threads):
    """Build the SessionOptions used by get_onnx_runtime_sessions when default=False.

    Args:
        opt_level (int): 0 -> ORT_DISABLE_ALL, 1 -> ORT_ENABLE_BASIC,
            2 -> ORT_ENABLE_EXTENDED, 99 -> ORT_ENABLE_ALL.
        parallel_exe_mode (bool): truthy -> ORT_PARALLEL, falsy -> ORT_SEQUENTIAL.
        n_threads (int): value assigned to ``intra_op_num_threads``.

    Returns:
        SessionOptions: a new options object configured as above.

    Raises:
        ValueError: if ``opt_level`` is not one of 0, 1, 2 or 99.
    """
    levels = {
        0: GraphOptimizationLevel.ORT_DISABLE_ALL,
        1: GraphOptimizationLevel.ORT_ENABLE_BASIC,
        2: GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
        99: GraphOptimizationLevel.ORT_ENABLE_ALL,
    }
    if opt_level not in levels:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise ValueError(
            f"opt_level must be one of {sorted(levels)}, got {opt_level!r}"
        )

    options = SessionOptions()
    options.graph_optimization_level = levels[opt_level]
    options.execution_mode = (
        ExecutionMode.ORT_PARALLEL
        if parallel_exe_mode
        else ExecutionMode.ORT_SEQUENTIAL
    )
    options.intra_op_num_threads = n_threads
    return options


def get_onnx_runtime_sessions(
    model_paths,
    default: bool = True,
    opt_level: int = 99,
    parallel_exe_mode: bool = True,
    n_threads: int = 0,
    provider=None,
) -> "tuple[InferenceSession, InferenceSession, InferenceSession]":
    """Create ONNX Runtime sessions for the encoder, decoder and initial decoder.

    Args:
        model_paths (list or tuple of str / path-like): exactly three paths, in order:
            path_to_encoder: the input ONNX encoder model.
            path_to_decoder: the input ONNX decoder model.
            path_to_initial_decoder: the input initial ONNX decoder model.
        default (bool): if True, build every session with ONNX Runtime's default
            session options; ``opt_level``, ``parallel_exe_mode`` and ``n_threads``
            are then ignored. ``provider`` is honoured either way.
        opt_level (int): graph optimization level used when ``default`` is False:
            0 -> 'ORT_DISABLE_ALL', 1 -> 'ORT_ENABLE_BASIC',
            2 -> 'ORT_ENABLE_EXTENDED', 99 -> 'ORT_ENABLE_ALL' (default).
        parallel_exe_mode (bool): when ``default`` is False, use parallel execution
            if truthy, sequential otherwise. Default True.
        n_threads (int): when ``default`` is False, the number of threads used to
            parallelize execution within nodes. 0 (default) lets onnxruntime choose.
        provider (list, optional): execution providers passed to every session.
            Defaults to ``["CPUExecutionProvider"]``.

    Returns:
        tuple: ``(encoder_sess, decoder_sess, decoder_sess_init)``, the
        InferenceSessions for the encoder, decoder and initial decoder models.
        When ``default`` is False, all three share one SessionOptions object.

    Raises:
        ValueError: if ``model_paths`` does not unpack into exactly three items,
            or (only when ``default`` is False) ``opt_level`` is unsupported.
    """
    path_to_encoder, path_to_decoder, path_to_initial_decoder = model_paths

    # A fresh list per call instead of a shared mutable default argument.
    providers = ["CPUExecutionProvider"] if provider is None else provider

    # sess_options=None makes InferenceSession fall back to its default options.
    options = (
        None
        if default
        else _build_session_options(opt_level, parallel_exe_mode, n_threads)
    )

    # Providers are passed in the default path too: a caller-supplied ``provider``
    # used to be silently ignored there, and ONNX Runtime >= 1.9 refuses to create
    # a session without explicit providers on builds that expose more than one.
    encoder_sess = InferenceSession(
        str(path_to_encoder), options, providers=providers
    )
    decoder_sess = InferenceSession(
        str(path_to_decoder), options, providers=providers
    )
    decoder_sess_init = InferenceSession(
        str(path_to_initial_decoder), options, providers=providers
    )

    return encoder_sess, decoder_sess, decoder_sess_init
|
|