| import onnxscript |
| import onnx_ir as ir |
| import onnx_ir.passes.common |
| import numpy as np |
| import onnxslim |
|
|
|
|
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace ``Reshape -> DFT -> Slice(real part) -> Squeeze`` with one ``MatMul``.

    For a real signal ``x`` of length ``N``, the real part of its forward DFT is
    ``Re X[k] = sum_n x[n] * cos(2*pi*k*n / N)``. That is ``x @ C`` with
    ``C[n, k] = cos(2*pi*k*n / N)``. The rewrite precomputes ``C`` as an
    initializer and multiplies the Reshape's *input* by it.

    NOTE(review): feeding the pre-Reshape tensor straight into the MatMul makes
    two assumptions, neither of which is checked here:

    * the Reshape only appends the trailing size-1 "real" axis that DFT expects;
    * the DFT runs over the axis just before that one.

    Confirm both against the model.
    """

    def pattern(self, op, x, dft_length):
        x = op.Reshape(x, _allow_other_inputs=True)
        dft = op.DFT(x, dft_length, _outputs=["dft_output"])
        # DFT emits a trailing [real, imag] axis; index 0 on it is the real part.
        real_part = op.Slice(dft, [0], [1], [-1])
        return op.Squeeze(real_part, [-1])

    def check(self, context, dft_length: ir.Value, dft_output: ir.Value, **_) -> bool:
        """Match only forward DFTs whose ``dft_length`` is a constant positive scalar.

        Without this guard:

        * a non-constant ``dft_length`` made ``rewrite`` fail on ``None.numpy()``;
        * an inverse DFT was silently replaced by the forward cosine basis.
        """
        length = ir.convenience.get_const_tensor(dft_length)
        if length is None:
            return False
        values = length.numpy()
        if values.size != 1 or int(values.item()) <= 0:
            return False
        dft_node = dft_output.producer()
        if dft_node is None:
            return False
        return dft_node.attributes.get_int("inverse", 0) == 0

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
        """Build ``MatMul(x, C)``, where ``C`` is the real-part DFT basis."""
        dft_node = dft_output.producer()
        assert dft_node is not None

        dft_size = int(ir.convenience.get_const_tensor(dft_length).numpy().item())
        # With onesided=1, DFT keeps only the N//2 + 1 non-redundant bins.
        # The ONNX default (onesided=0) returns all N bins. The basis must
        # produce the same number of columns as the DFT it replaces.
        onesided = dft_node.attributes.get_int("onesided", 0)
        num_freqs = dft_size // 2 + 1 if onesided else dft_size

        # Reduce k*n modulo N exactly in integers, then evaluate cos in float64.
        # Computing 2*pi*k*n/N directly in float32 loses several digits of phase
        # accuracy once k*n reaches ~1e6 (N in the thousands).
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]
        k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]
        phase_index = (k * n) % dft_size
        basis = np.cos(2.0 * np.pi * phase_index / dft_size)

        # Match x's element type so the MatMul type-checks.
        # Fall back to float32 when x's dtype is unknown.
        np_dtype = x.dtype.numpy() if x.dtype is not None else np.float32
        dft_matrix = op.initializer(
            ir.tensor(basis.astype(np_dtype)), name=f"{x.name}_dft_matrix"
        )
        return op.MatMul(x, dft_matrix)
|
|
|
|
class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace a two-output ``Split`` of ``x`` with a gathered value and a constant.

    The replacement values are:

    * first output: ``Gather(x, [0])``, i.e. element 0 of ``x`` (int64 index
      initializer ``"zero"``);
    * second output: the int32 constant ``[144000]`` (initializer ``"sample_size"``).

    NOTE(review): the pattern matches *every* two-output Split in the graph. It
    assumes ``x`` is a shape-like vector whose second entry is always 144000,
    the fixed input length this script assigns to the graph input. Confirm the
    model has no other two-way Splits.
    """

    def pattern(self, op, x):
        return op.Split(
            x, _allow_other_inputs=True, _outputs=["split_out_1", "split_out_2"]
        )

    def rewrite(self, op, x: ir.Value, **_bindings):
        first_index = op.initializer(
            ir.tensor(np.array([0], dtype=np.int64)), "zero"
        )
        fixed_length = ir.tensor(np.array([144000], dtype=np.int32))
        leading = op.Gather(x, first_index)
        return leading, op.initializer(fixed_length, "sample_size")
|
|
|
|
class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
    """Drop every ``Cast`` node by forwarding its input through ``Identity``.

    NOTE(review): the Cast's target dtype is discarded, so consumers see the
    *uncast* element type. Presumably every Cast in this model is redundant
    after the other rewrites. Confirm, otherwise the exported graph may not
    type-check.
    """

    def pattern(self, op, x):
        return op.Cast(x)

    def rewrite(self, op, x: ir.Value, **_unused):
        passthrough = op.Identity(x)
        return passthrough
|
|
|
|
class RemoveReversedSequenceFork(onnxscript.rewriter.RewriteRuleClassBase):
    """Collapse a Transpose/ReverseSequence fork into last-axis reversals.

    Pattern: two parallel branches ``Transpose -> ReverseSequence -> Unsqueeze``
    (on ``x`` and ``y``), concatenated, then ``Mul(scale)``, ``Add(bias)`` and a
    final ``Transpose``.

    Replacement: ``x`` and ``y`` are reversed along their last axis with a
    negative-step ``Slice`` and unsqueezed on a new trailing axis. They are then
    concatenated on axis 3, scaled and shifted, and finally transposed with
    ``perm=[0, 3, 2, 1]``.

    NOTE(review): ReverseSequence reverses each batch entry only up to its
    ``sequence_lens``. Swapping in a full reversal assumes every length equals
    the full axis size. The fixed axis/perm values also assume one specific
    layout. Neither assumption is checked; confirm against the model.
    """

    def pattern(self, op, x, y, scale, bias):
        transposed = [op.Transpose(branch) for branch in (x, y)]
        reversed_branches = [
            op.ReverseSequence(branch, _allow_other_inputs=True)
            for branch in transposed
        ]
        expanded = [
            op.Unsqueeze(branch, _allow_other_inputs=True)
            for branch in reversed_branches
        ]
        merged = op.Concat(*expanded)
        return op.Transpose(op.Add(op.Mul(merged, scale), bias))

    def rewrite(self, op, x, y, scale, bias, **_extra):
        last_axis = op.initializer(
            ir.tensor(np.array([-1], dtype=np.int64)), "neg_one"
        )
        lowest_index = op.initializer(
            ir.tensor(np.array([np.iinfo(np.int64).min], dtype=np.int64)),
            "int_64_min",
        )
        # Slice(data, starts=-1, ends=INT64_MIN, axes=-1, steps=-1) walks the
        # last axis backwards from its final element through index 0.
        flipped = [
            op.Slice(branch, last_axis, lowest_index, last_axis, last_axis)
            for branch in (x, y)
        ]
        widened = [op.Unsqueeze(branch, last_axis) for branch in flipped]
        stacked = op.Concat(*widened, axis=3)
        return op.Transpose(
            op.Add(op.Mul(stacked, scale), bias), perm=[0, 3, 2, 1]
        )
|
|
|
|
# Load the source model and pin its graph input/output shapes
# (symbolic batch dimension, fixed inner dimension).
model = ir.load("model.onnx")
model.graph.inputs[0].shape = ir.Shape(["batch", 144000])
model.graph.outputs[0].shape = ir.Shape(["batch", 6522])

# First rewrite pass:
# - DFT real part -> MatMul
# - replace the two-way Split
# - drop Cast nodes
first_pass_rules = [
    ReplaceDftWithMatMulRule().rule(),
    ReplaceSplit().rule(),
    RemoveCast().rule(),
]
onnxscript.rewriter.rewrite(model, first_pass_rules)

# Promote every INT32 initializer to an INT64 initializer with the same name,
# and rewire all of its uses to the new value.
# NOTE(review): only the initializers are retyped, not the ops that consume
# them. Confirm the resulting graph still type-checks.
int32_initializers = [
    value
    for value in model.graph.initializers.values()
    if value.dtype == ir.DataType.INT32
]
for old_value in int32_initializers:
    widened = ir.tensor(old_value.const_value.numpy().astype(np.int64))
    replacement = ir.val(old_value.name, const_value=widened)
    model.graph.initializers.pop(old_value.name)
    model.graph.initializers.add(replacement)
    old_value.replace_all_uses_with(replacement)

# Run the onnxscript optimizer with 1 GiB input/output size limits.
onnxscript.optimizer.optimize(
    model, input_size_limit=1 << 30, output_size_limit=1 << 30
)
|
|
|
|
| |
def remove_slice_reshape(model: ir.Model):
    """Point the STFT framing Reshapes at the Mul output with fixed target shapes.

    Registers four INT64 shape initializers (``first_shape`` .. ``fourth_shape``)
    on ``model.graph``. It then rewires the following nodes, looked up by exact
    node name:

    * ``MEL_SPEC1`` / ``MEL_SPEC2`` ``stft/frame/Reshape_1``:
      - data input -> output 0 of ``model/MEL_SPEC1/Mul``;
      - shape input -> ``[-1, 72000, 2]`` / ``[-1, 18000, 8]``.
    * ``MEL_SPEC1`` / ``MEL_SPEC2`` ``stft/frame/Reshape_4``:
      - shape input only -> ``[-1, 511, 2048]`` / ``[-1, 511, 1024]``.

    The previous producers of the rewired inputs are not removed here. Per the
    function name, these are presumably Slice nodes; verify.

    The Reshape_1 targets are hard-coded for 144000 samples
    (72000 * 2 == 18000 * 8 == 144000).
    """
    graph = model.graph
    mul_node = graph.node("model/MEL_SPEC1/Mul")

    # Each entry: (reshape node name, shape initializer name, target shape,
    # whether its data input is also rerouted to the Mul output).
    plan = (
        ("model/MEL_SPEC1/stft/frame/Reshape_1", "first_shape", [-1, 72000, 2], True),
        ("model/MEL_SPEC2/stft/frame/Reshape_1", "second_shape", [-1, 18000, 8], True),
        ("model/MEL_SPEC1/stft/frame/Reshape_4", "third_shape", [-1, 511, 2048], False),
        ("model/MEL_SPEC2/stft/frame/Reshape_4", "fourth_shape", [-1, 511, 1024], False),
    )

    # Resolve every node and register every shape constant before rewiring anything.
    resolved = []
    for node_name, shape_name, dims, reroute_data in plan:
        reshape = graph.node(node_name)
        shape_value = ir.val(
            shape_name, const_value=ir.tensor(dims, dtype=ir.DataType.INT64)
        )
        graph.initializers.add(shape_value)
        resolved.append((reshape, shape_value, reroute_data))

    for reshape, shape_value, reroute_data in resolved:
        if reroute_data:
            reshape.replace_input_with(0, mul_node.outputs[0])
        reshape.replace_input_with(1, shape_value)
|
|
|
|
remove_slice_reshape(model)

# Run the onnxscript optimizer again with 1 GiB input/output size limits.
onnxscript.optimizer.optimize(
    model, input_size_limit=1 << 30, output_size_limit=1 << 30
)


def _slim(m: ir.Model) -> ir.Model:
    """Run *m* through onnxslim via a proto round-trip; return the new IR model."""
    return ir.from_proto(onnxslim.slim(ir.to_proto(m)))


print("Slimming model...")
model = _slim(model)

print("Removing reversed sequence fork...")
onnxscript.rewriter.rewrite(model, [RemoveReversedSequenceFork.rule()])

# Slim once more after the fork rewrite.
model = _slim(model)

# Strip metadata and doc strings, then set the public I/O names and header fields.
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
model.graph.inputs[0].name = "input"
model.graph.outputs[0].name = "output"
model.ir_version = 10
model.producer_name = "onnx-ir"
model.graph.name = "BirdNET-v2.4"

ir.save(model, "birdnet.onnx")
|
|