| """ |
| Script to compare the results of an ONNX model with a TFLite model given the same input. |
| Optionally also compare with Tract runtime for ONNX. |
| Created by Copilot. |
| |
| Usage: |
| python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite |
| python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npy |
| python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-5 --atol 1e-5 |
| python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark |
| python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark |
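
Preparing an input file (a sketch; "input" is a hypothetical input name, and the
shape/dtype depend on your model). The .npz keys must match the ONNX input names:

    import numpy as np
    np.savez("input.npz", input=np.random.randn(1, 3, 224, 224).astype(np.float32))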
| """ |
|
|
| import argparse |
| import time |
| import numpy as np |
| import onnxruntime as ort |
| import tensorflow as tf |
| from typing import Dict, List, Tuple, Optional, Any |
|
|
| try: |
| import tract |
|
|
| TRACT_AVAILABLE = True |
| except ImportError: |
| TRACT_AVAILABLE = False |
|
|
|
|


def load_onnx_model(onnx_path: str) -> ort.InferenceSession:
    """Load an ONNX model and return an inference session."""
    print(f"Loading ONNX model from: {onnx_path}")
    session = ort.InferenceSession(onnx_path)
    return session


def load_tflite_model(tflite_path: str) -> tf.lite.Interpreter:
    """Load a TFLite model and return an interpreter."""
    print(f"Loading TFLite model from: {tflite_path}")
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    return interpreter
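

# An aside (a sketch, not used by this script): tf.lite.Interpreter also accepts
# num_threads for multi-threaded CPU inference, e.g.
#     tf.lite.Interpreter(model_path="model.tflite", num_threads=4)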


def load_tract_model(onnx_path: str) -> Optional[Any]:
    """Load an ONNX model using tract and return a runnable model."""
    if not TRACT_AVAILABLE:
        print("Tract is not available. Install with: pip install tract")
        return None
    print(f"Loading ONNX model with tract from: {onnx_path}")
    model = tract.onnx().model_for_path(onnx_path).into_optimized().into_runnable()
    return model


def get_onnx_model_info(session: ort.InferenceSession) -> Tuple[List, List]:
    """Get input and output information from ONNX model."""
    inputs = session.get_inputs()
    outputs = session.get_outputs()

    print("\nONNX Model Information:")
    print("Inputs:")
    for inp in inputs:
        print(f" - Name: {inp.name}, Shape: {inp.shape}, Type: {inp.type}")
    print("Outputs:")
    for out in outputs:
        print(f" - Name: {out.name}, Shape: {out.shape}, Type: {out.type}")

    return inputs, outputs


def get_tflite_model_info(interpreter: tf.lite.Interpreter) -> Tuple[List, List]:
    """Get input and output information from TFLite model."""
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print("\nTFLite Model Information:")
    print("Inputs:")
    for inp in input_details:
        print(f" - Name: {inp['name']}, Shape: {inp['shape']}, Type: {inp['dtype']}")
    print("Outputs:")
    for out in output_details:
        print(f" - Name: {out['name']}, Shape: {out['shape']}, Type: {out['dtype']}")

    return input_details, output_details


def generate_random_inputs(onnx_inputs: List, seed: int = 42) -> Dict[str, np.ndarray]:
    """Generate random inputs based on ONNX model input specs."""
    np.random.seed(seed)
    inputs = {}

    print("\nGenerating random inputs:")
    for inp in onnx_inputs:
        # Resolve the input shape, pinning dynamic dimensions to 1.
        shape = []
        for dim in inp.shape:
            if isinstance(dim, str) or dim is None or dim < 0:
                # Symbolic or unknown dimension.
                shape.append(1)
            else:
                shape.append(dim)

        # Draw random data matching the declared element type.
        if "float" in inp.type.lower():
            data = np.random.randn(*shape).astype(np.float32)
        elif "int64" in inp.type.lower():
            data = np.random.randint(0, 100, size=shape).astype(np.int64)
        elif "int32" in inp.type.lower():
            data = np.random.randint(0, 100, size=shape).astype(np.int32)
        else:
            # Fall back to float32 for unrecognized types.
            data = np.random.randn(*shape).astype(np.float32)

        inputs[inp.name] = data
        print(f" - {inp.name}: shape={data.shape}, dtype={data.dtype}")

    return inputs


def load_inputs_from_file(input_path: str) -> Dict[str, np.ndarray]:
    """Load inputs from a numpy file (.npy or .npz)."""
    print(f"\nLoading inputs from: {input_path}")

    if input_path.endswith(".npz"):
        data = np.load(input_path)
        inputs = {key: data[key] for key in data.files}
    elif input_path.endswith(".npy"):
        data = np.load(input_path)
        # A bare .npy carries no name, so store it under a generic key.
        inputs = {"input": data}
    else:
        raise ValueError("Input file must be .npy or .npz format")

    for name, value in inputs.items():
        print(f" - {name}: shape={value.shape}, dtype={value.dtype}")

    return inputs
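

# A quick sanity check (a sketch, not wired into the script) that an .npz file's
# keys line up with the ONNX input names; the paths are placeholders:
#
#     session = load_onnx_model("model.onnx")
#     data = np.load("input.npz")
#     print(sorted(data.files) == sorted(i.name for i in session.get_inputs()))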


def run_onnx_model(
    session: ort.InferenceSession, inputs: Dict[str, np.ndarray]
) -> List[np.ndarray]:
    """Run inference on ONNX model."""
    print("\nRunning ONNX model inference...")
    outputs = session.run(None, inputs)
    return outputs


def run_tflite_model(
    interpreter: tf.lite.Interpreter, inputs: Dict[str, np.ndarray], input_details: List
) -> List[np.ndarray]:
    """Run inference on TFLite model."""
    print("Running TFLite model inference...")

    # Match the provided inputs to the interpreter's input tensors.
    for i, detail in enumerate(input_details):
        # Prefer an exact name match, then fall back to positional matching.
        input_data = None
        if detail["name"] in inputs:
            input_data = inputs[detail["name"]]
        elif len(inputs) == 1:
            # Single input: the name does not need to match.
            input_data = list(inputs.values())[0]
        elif i < len(inputs):
            # Match by position.
            input_data = list(inputs.values())[i]
        else:
            raise ValueError(f"Cannot match input for TFLite input {detail['name']}")

        # Cast to the dtype the interpreter expects.
        if input_data.dtype != detail["dtype"]:
            input_data = input_data.astype(detail["dtype"])

        interpreter.set_tensor(detail["index"], input_data)

    # Run inference.
    interpreter.invoke()

    # Collect all output tensors.
    output_details = interpreter.get_output_details()
    outputs = []
    for detail in output_details:
        outputs.append(interpreter.get_tensor(detail["index"]))

    return outputs


def run_tract_model(model: Any, inputs: Dict[str, np.ndarray]) -> List[np.ndarray]:
    """Run inference on tract model."""
    if model is None:
        return []
    print("Running tract model inference...")

    # tract takes inputs positionally, in the model's declared order.
    input_list = list(inputs.values())

    # Run inference.
    outputs = model.run(input_list)

    # Convert tract tensors back to numpy arrays.
    result = []
    for output in outputs:
        result.append(output.to_numpy())

    return result


def benchmark_onnx_model(
    session: ort.InferenceSession,
    inputs: Dict[str, np.ndarray],
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> Dict[str, float]:
    """Benchmark ONNX model inference speed."""
    print(f"\nBenchmarking ONNX model ({warmup_runs} warmup + {num_runs} test runs)...")

    # Warmup runs (not timed).
    for _ in range(warmup_runs):
        session.run(None, inputs)

    # Timed runs, recorded in milliseconds.
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        session.run(None, inputs)
        end = time.perf_counter()
        times.append((end - start) * 1000)

    return {
        "mean": np.mean(times),
        "median": np.median(times),
        "std": np.std(times),
        "min": np.min(times),
        "max": np.max(times),
    }


def benchmark_tflite_model(
    interpreter: tf.lite.Interpreter,
    inputs: Dict[str, np.ndarray],
    input_details: List,
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> Dict[str, float]:
    """Benchmark TFLite model inference speed."""
    print(f"Benchmarking TFLite model ({warmup_runs} warmup + {num_runs} test runs)...")

    # Input setup mirrors run_tflite_model; it is excluded from the timing.
    def set_inputs():
        for i, detail in enumerate(input_details):
            input_data = None
            if detail["name"] in inputs:
                input_data = inputs[detail["name"]]
            elif len(inputs) == 1:
                input_data = list(inputs.values())[0]
            elif i < len(inputs):
                input_data = list(inputs.values())[i]
            else:
                raise ValueError(
                    f"Cannot match input for TFLite input {detail['name']}"
                )

            if input_data.dtype != detail["dtype"]:
                input_data = input_data.astype(detail["dtype"])

            interpreter.set_tensor(detail["index"], input_data)

    # Warmup runs (not timed).
    for _ in range(warmup_runs):
        set_inputs()
        interpreter.invoke()

    # Timed runs, recorded in milliseconds.
    times = []
    for _ in range(num_runs):
        set_inputs()
        start = time.perf_counter()
        interpreter.invoke()
        end = time.perf_counter()
        times.append((end - start) * 1000)

    return {
        "mean": np.mean(times),
        "median": np.median(times),
        "std": np.std(times),
        "min": np.min(times),
        "max": np.max(times),
    }


def benchmark_tract_model(
    model: Any,
    inputs: Dict[str, np.ndarray],
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> Optional[Dict[str, float]]:
    """Benchmark tract model inference speed."""
    if model is None:
        return None
    print(f"Benchmarking tract model ({warmup_runs} warmup + {num_runs} test runs)...")

    # tract takes inputs positionally.
    input_list = list(inputs.values())

    # Warmup runs (not timed).
    for _ in range(warmup_runs):
        model.run(input_list)

    # Timed runs, recorded in milliseconds.
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        model.run(input_list)
        end = time.perf_counter()
        times.append((end - start) * 1000)

    return {
        "mean": np.mean(times),
        "median": np.median(times),
        "std": np.std(times),
        "min": np.min(times),
        "max": np.max(times),
    }


def print_benchmark_results(
    onnx_stats: Dict[str, float],
    tflite_stats: Dict[str, float],
    tract_stats: Optional[Dict[str, float]] = None,
) -> None:
    """Print benchmark comparison results."""
    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80)

    print("\nONNX Model:")
    print(f" Mean: {onnx_stats['mean']:.3f} ms")
    print(f" Median: {onnx_stats['median']:.3f} ms")
    print(f" Std: {onnx_stats['std']:.3f} ms")
    print(f" Min: {onnx_stats['min']:.3f} ms")
    print(f" Max: {onnx_stats['max']:.3f} ms")

    print("\nTFLite Model:")
    print(f" Mean: {tflite_stats['mean']:.3f} ms")
    print(f" Median: {tflite_stats['median']:.3f} ms")
    print(f" Std: {tflite_stats['std']:.3f} ms")
    print(f" Min: {tflite_stats['min']:.3f} ms")
    print(f" Max: {tflite_stats['max']:.3f} ms")

    if tract_stats:
        print("\nTract Model:")
        print(f" Mean: {tract_stats['mean']:.3f} ms")
        print(f" Median: {tract_stats['median']:.3f} ms")
        print(f" Std: {tract_stats['std']:.3f} ms")
        print(f" Min: {tract_stats['min']:.3f} ms")
        print(f" Max: {tract_stats['max']:.3f} ms")

    print("\nComparison:")
    # A ratio > 1 means the runtime in the denominator is faster.
    speedup = tflite_stats["mean"] / onnx_stats["mean"]
    if speedup > 1:
        print(f" ONNX Runtime is {speedup:.2f}x faster than TFLite")
    else:
        print(f" TFLite is {1 / speedup:.2f}x faster than ONNX Runtime")
    print(f" Difference: {abs(onnx_stats['mean'] - tflite_stats['mean']):.3f} ms")

    if tract_stats:
        speedup_tract = tflite_stats["mean"] / tract_stats["mean"]
        if speedup_tract > 1:
            print(f" Tract is {speedup_tract:.2f}x faster than TFLite")
        else:
            print(f" TFLite is {1 / speedup_tract:.2f}x faster than Tract")
        print(f" Difference: {abs(tract_stats['mean'] - tflite_stats['mean']):.3f} ms")

        speedup_ort = onnx_stats["mean"] / tract_stats["mean"]
        if speedup_ort > 1:
            print(f" Tract is {speedup_ort:.2f}x faster than ONNX Runtime")
        else:
            print(f" ONNX Runtime is {1 / speedup_ort:.2f}x faster than Tract")
        print(f" Difference: {abs(tract_stats['mean'] - onnx_stats['mean']):.3f} ms")

    print("=" * 80)


def compare_outputs(
    onnx_outputs: List[np.ndarray],
    tflite_outputs: List[np.ndarray],
    tract_outputs: Optional[List[np.ndarray]] = None,
    rtol: float = 1e-5,
    atol: float = 1e-5,
) -> bool:
    """Compare outputs from ONNX, TFLite, and optionally Tract models."""
    print("\n" + "=" * 80)
    print("COMPARISON RESULTS")
    print("=" * 80)

    if len(onnx_outputs) != len(tflite_outputs):
        print(
            f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, TFLite={len(tflite_outputs)}"
        )
        return False

    if tract_outputs and len(onnx_outputs) != len(tract_outputs):
        print(
            f"❌ Number of outputs differs: ONNX={len(onnx_outputs)}, Tract={len(tract_outputs)}"
        )
        return False

    all_match = True
    for i, (onnx_out, tflite_out) in enumerate(zip(onnx_outputs, tflite_outputs)):
        tract_out = tract_outputs[i] if tract_outputs else None

        print(f"\nOutput {i}:")
        print(f" ONNX Runtime shape: {onnx_out.shape}, dtype: {onnx_out.dtype}")
        print(f" TFLite shape: {tflite_out.shape}, dtype: {tflite_out.dtype}")
        if tract_out is not None:
            print(f" Tract shape: {tract_out.shape}, dtype: {tract_out.dtype}")

        if onnx_out.shape != tflite_out.shape:
            print(" ❌ Shape mismatch between ONNX and TFLite!")
            all_match = False
            continue

        if tract_out is not None and onnx_out.shape != tract_out.shape:
            print(" ❌ Shape mismatch between ONNX and Tract!")
            all_match = False
            continue

        # Normalize dtypes before the numeric comparison.
        if onnx_out.dtype != tflite_out.dtype:
            print(" ⚠️ Different dtypes, converting to float32 for comparison")
            onnx_out = onnx_out.astype(np.float32)
            tflite_out = tflite_out.astype(np.float32)

        if tract_out is not None and onnx_out.dtype != tract_out.dtype:
            tract_out = tract_out.astype(np.float32)

        # ONNX Runtime vs TFLite. np.allclose passes when
        # |a - b| <= atol + rtol * |b| holds elementwise.
        print("\n ONNX Runtime vs TFLite:")
        diff = np.abs(onnx_out - tflite_out)
        max_diff = np.max(diff)
        mean_diff = np.mean(diff)
        is_close = np.allclose(onnx_out, tflite_out, rtol=rtol, atol=atol)

        print(f" Max difference: {max_diff:.10f}")
        print(f" Mean difference: {mean_diff:.10f}")
        print(f" Relative tolerance: {rtol}")
        print(f" Absolute tolerance: {atol}")

        if is_close:
            print(" ✅ Outputs match within tolerance")
        else:
            print(" ❌ Outputs do NOT match within tolerance")
            all_match = False

        # Show a few values side by side for eyeballing.
        print("\n Sample values (first 5 elements):")
        flat_onnx = onnx_out.flatten()[:5]
        flat_tflite = tflite_out.flatten()[:5]
        for j, (o, t) in enumerate(zip(flat_onnx, flat_tflite)):
            print(
                f" [{j}] ONNX: {o:.10f}, TFLite: {t:.10f}, Diff: {abs(o - t):.10f}"
            )

        # ONNX Runtime vs tract.
        if tract_out is not None:
            print("\n ONNX Runtime vs Tract:")
            diff_tract = np.abs(onnx_out - tract_out)
            max_diff_tract = np.max(diff_tract)
            mean_diff_tract = np.mean(diff_tract)
            is_close_tract = np.allclose(onnx_out, tract_out, rtol=rtol, atol=atol)

            print(f" Max difference: {max_diff_tract:.10f}")
            print(f" Mean difference: {mean_diff_tract:.10f}")

            if is_close_tract:
                print(" ✅ Outputs match within tolerance")
            else:
                print(" ❌ Outputs do NOT match within tolerance")
                all_match = False

            print("\n Sample values (first 5 elements):")
            flat_onnx_tract = onnx_out.flatten()[:5]
            flat_tract = tract_out.flatten()[:5]
            for j, (o, tr) in enumerate(zip(flat_onnx_tract, flat_tract)):
                print(
                    f" [{j}] ONNX: {o:.10f}, Tract: {tr:.10f}, Diff: {abs(o - tr):.10f}"
                )

            # TFLite vs tract.
            print("\n TFLite vs Tract:")
            diff_tflite_tract = np.abs(tflite_out - tract_out)
            max_diff_tflite_tract = np.max(diff_tflite_tract)
            mean_diff_tflite_tract = np.mean(diff_tflite_tract)
            is_close_tflite_tract = np.allclose(
                tflite_out, tract_out, rtol=rtol, atol=atol
            )

            print(f" Max difference: {max_diff_tflite_tract:.10f}")
            print(f" Mean difference: {mean_diff_tflite_tract:.10f}")

            if is_close_tflite_tract:
                print(" ✅ Outputs match within tolerance")
            else:
                print(" ❌ Outputs do NOT match within tolerance")
                all_match = False

    print("\n" + "=" * 80)
    if all_match:
        print("✅ ALL OUTPUTS MATCH!")
    else:
        print("❌ SOME OUTPUTS DO NOT MATCH")
    print("=" * 80)

    return all_match


def main():
    parser = argparse.ArgumentParser(
        description="Compare ONNX and TFLite model outputs",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Compare with random inputs
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite

  # Compare with custom inputs from file
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --input input.npz

  # Compare with custom tolerances
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --rtol 1e-3 --atol 1e-3

  # Save outputs for inspection
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --save-outputs

  # Benchmark execution speed
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark

  # Benchmark with a custom number of runs
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --benchmark --num-runs 200 --warmup-runs 20

  # Compare with the tract runtime as well
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract

  # Benchmark all three runtimes
  python compare_onnx_tflite.py --onnx model.onnx --tflite model.tflite --use-tract --benchmark
""",
    )

    parser.add_argument("--onnx", required=True, help="Path to ONNX model")
    parser.add_argument("--tflite", required=True, help="Path to TFLite model")
    parser.add_argument("--input", help="Path to input file (.npy or .npz)")
    parser.add_argument(
        "--rtol", type=float, default=1e-5, help="Relative tolerance (default: 1e-5)"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-5, help="Absolute tolerance (default: 1e-5)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for input generation (default: 42)",
    )
    parser.add_argument(
        "--save-outputs", action="store_true", help="Save outputs to files"
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Benchmark execution speed of both models",
    )
    parser.add_argument(
        "--num-runs",
        type=int,
        default=100,
        help="Number of benchmark runs (default: 100)",
    )
    parser.add_argument(
        "--warmup-runs",
        type=int,
        default=10,
        help="Number of warmup runs (default: 10)",
    )
    parser.add_argument(
        "--use-tract", action="store_true", help="Also test with the tract ONNX runtime"
    )

    args = parser.parse_args()

    # Load both models.
    onnx_session = load_onnx_model(args.onnx)
    tflite_interpreter = load_tflite_model(args.tflite)

    # Optionally load the tract runtime.
    tract_model = None
    if args.use_tract:
        if not TRACT_AVAILABLE:
            print(
                "\n⚠️ Warning: Tract is not installed. Install with: pip install tract"
            )
            print("Continuing without tract comparison...\n")
        else:
            tract_model = load_tract_model(args.onnx)

    # Print model I/O information.
    onnx_inputs, onnx_outputs = get_onnx_model_info(onnx_session)
    tflite_input_details, tflite_output_details = get_tflite_model_info(
        tflite_interpreter
    )

    # Prepare inputs: load from file if given, otherwise generate random data.
    if args.input:
        inputs = load_inputs_from_file(args.input)
        # A bare .npy is stored under the generic key "input"; remap it to the
        # model's actual input name so ONNX Runtime can resolve it.
        if list(inputs) == ["input"] and len(onnx_inputs) == 1:
            inputs = {onnx_inputs[0].name: inputs["input"]}
    else:
        inputs = generate_random_inputs(onnx_inputs, seed=args.seed)

    # Run inference on each runtime.
    onnx_results = run_onnx_model(onnx_session, inputs)
    tflite_results = run_tflite_model(tflite_interpreter, inputs, tflite_input_details)
    tract_results = None
    if tract_model:
        tract_results = run_tract_model(tract_model, inputs)

    # Optionally save outputs for offline inspection.
    if args.save_outputs:
        print("\nSaving outputs...")
        np.savez("onnx_outputs.npz", *onnx_results)
        np.savez("tflite_outputs.npz", *tflite_results)
        print(" - onnx_outputs.npz")
        print(" - tflite_outputs.npz")
        if tract_results:
            np.savez("tract_outputs.npz", *tract_results)
            print(" - tract_outputs.npz")

    # Compare outputs across runtimes.
    match = compare_outputs(
        onnx_results, tflite_results, tract_results, rtol=args.rtol, atol=args.atol
    )

    # Optionally benchmark.
    if args.benchmark:
        onnx_stats = benchmark_onnx_model(
            onnx_session, inputs, args.num_runs, args.warmup_runs
        )
        tflite_stats = benchmark_tflite_model(
            tflite_interpreter,
            inputs,
            tflite_input_details,
            args.num_runs,
            args.warmup_runs,
        )
        tract_stats = None
        if tract_model:
            tract_stats = benchmark_tract_model(
                tract_model, inputs, args.num_runs, args.warmup_runs
            )
        print_benchmark_results(onnx_stats, tflite_stats, tract_stats)

    # Exit code reflects whether all outputs matched.
    return 0 if match else 1


if __name__ == "__main__":
    sys.exit(main())