| import logging |
| import os |
| import onnx |
| import tensorrt as trt |
| from typing import List |
| from collections import OrderedDict |
| from onnx import shape_inference |
|
|
|
|
def _pin_layers_to_fp32(network, config, keywords):
    """Force layers whose name contains any of *keywords* to FP32.

    Sets the layer compute precision and every output type of each matching
    layer to ``trt.float32``. If at least one layer was pinned, the builder
    flag that makes TensorRT honor per-layer precision is enabled.

    Returns:
        ``(total_layers, pinned_layers)`` counts for reporting.
    """
    total = network.num_layers
    pinned = 0
    for idx in range(total):
        layer = network.get_layer(idx)
        if not any(key in layer.name for key in keywords):
            continue
        layer.precision = trt.float32
        for out_idx in range(layer.num_outputs):
            layer.set_output_type(out_idx, trt.float32)
        pinned += 1
    if pinned:
        # STRICT_TYPES is deprecated (and removed in TensorRT 10);
        # prefer OBEY_PRECISION_CONSTRAINTS when this TensorRT has it.
        flag = getattr(trt.BuilderFlag, "OBEY_PRECISION_CONSTRAINTS", None)
        if flag is None:
            flag = trt.BuilderFlag.STRICT_TYPES
        config.set_flag(flag)
    return total, pinned


def _build_serialized_engine(builder, network, config):
    """Build the network and return the serialized engine.

    Raises:
        RuntimeError: if TensorRT reports a build failure (returns None).
    """
    if hasattr(builder, "build_serialized_network"):
        # Available since TensorRT 8; build_engine was removed in TensorRT 10.
        serialized = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized = engine.serialize() if engine is not None else None
    if serialized is None:
        raise RuntimeError(
            "TensorRT failed to build the engine; see the TensorRT log for details."
        )
    return serialized


def vit_tagging_t2t(
    input_path="simple_model.onnx",
    output_path="vit.trt",
    *,
    input_name="input",
    input_shape=(3, 224, 224),
    min_batch=1,
    opt_batch=70,
    max_batch=70,
    fp16=True,
    fp32_layer_keywords=(),
):
    """Convert an ONNX model into a serialized TensorRT engine file.

    The ONNX file is parsed into an explicit-batch network. A single
    optimization profile is registered for ``input_name``, with the batch
    dimension ranging over ``min_batch``..``max_batch`` (tuned for
    ``opt_batch``) and a fixed per-sample ``input_shape``. The first network
    input and output are given an FP32 dtype.

    Args:
        input_path: Path of the ONNX model to convert.
        output_path: Destination path for the serialized engine.
        input_name: Name of the network input the profile applies to.
        input_shape: Per-sample input shape, excluding the batch dimension.
        min_batch: Smallest batch size the engine must accept.
        opt_batch: Batch size TensorRT optimizes kernels for.
        max_batch: Largest batch size the engine must accept.
        fp16: Allow FP16 kernels (``BuilderFlag.FP16``).
        fp32_layer_keywords: Substrings. Any layer whose name contains one of
            them is pinned to FP32, e.g. ``("ReduceMean", "Pow")``. A bare
            string is treated as a single keyword. Empty by default (no
            pinning).

    Raises:
        ValueError: if the batch sizes are not ``min <= opt <= max``.
        RuntimeError: if the ONNX file cannot be parsed or the build fails.
    """
    if not min_batch <= opt_batch <= max_batch:
        raise ValueError(
            f"expected min_batch <= opt_batch <= max_batch, got "
            f"{min_batch}, {opt_batch}, {max_batch}"
        )
    if isinstance(fp32_layer_keywords, str):
        fp32_layer_keywords = (fp32_layer_keywords,)
    # Materialize once: it is scanned for every layer in the network.
    keywords = tuple(fp32_layer_keywords)
    sample_shape = tuple(input_shape)

    logger = logging.getLogger(__name__)
    trt_logger = trt.Logger()
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(trt_logger) as builder, \
            builder.create_network(explicit_batch) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, trt_logger) as parser:
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)

        with open(input_path, "rb") as f:
            parsed = parser.parse(f.read())
        if not parsed:
            errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
            for err in errors:
                logger.error("ONNX parser: %s", err)
            raise RuntimeError(
                f"Failed to parse the ONNX file {input_path!r}:\n" + "\n".join(errors)
            )

        profile = builder.create_optimization_profile()
        profile.set_shape(
            input_name,
            min=(min_batch, *sample_shape),
            opt=(opt_batch, *sample_shape),
            max=(max_batch, *sample_shape),
        )
        config.add_optimization_profile(profile)

        total, pinned = _pin_layers_to_fp32(network, config, keywords)
        logger.info("Pinned %d of %d layers to FP32", pinned, total)

        network.get_input(0).dtype = trt.float32
        network.get_output(0).dtype = trt.float32

        serialized = _build_serialized_engine(builder, network, config)
        with open(output_path, "wb") as f:
            f.write(serialized)
| |
if __name__ == "__main__":
    # Script entry point: convert simple_model.onnx into the vit.trt engine.
    vit_tagging_t2t(input_path="simple_model.onnx", output_path="vit.trt")