| |
|
|
| import os |
| import time |
| import json |
| import numpy as np |
| import tensorflow as tf |
| import onnx |
| import onnxruntime as ort |
|
|
| from src.utils import get_logger, load_config |
| from data_loader import get_data_generators |
|
|
# Module-wide logger, created through the project's get_logger helper under
# the name "export_onnx"; every function in this file logs through it.
logger = get_logger("export_onnx")
|
|
|
|
| |
| |
| |
|
|
# Optional dependency guard for onnxruntime's quantization tooling.
#
# On success the quantization helpers are bound at module level and
# QUANTIZATION_AVAILABLE is True. On any failure the flag is False, the
# stringified exception is kept in QUANT_IMPORT_ERROR for later log messages,
# and a stand-in CalibrationDataReader is defined so that classes deriving
# from it can still be declared.
try:
    from onnxruntime.quantization import (
        quantize_dynamic,
        quantize_static,
        QuantType,
        CalibrationDataReader,
        QuantFormat,
    )
except Exception as import_exc:
    QUANTIZATION_AVAILABLE, QUANT_IMPORT_ERROR = False, str(import_exc)

    class CalibrationDataReader:
        """Fallback placeholder when quantization imports are unavailable."""
else:
    QUANTIZATION_AVAILABLE, QUANT_IMPORT_ERROR = True, None
|
|
|
|
| |
| |
| |
|
|
def export_to_onnx(model, onnx_path: str, image_size: tuple = (150, 150)):
    """Convert a Keras model to ONNX, validate the file and return its path.

    The model is converted with ``tf2onnx.convert.from_keras`` at opset 13
    using a single float32 input named ``"input"`` of shape
    ``(None, *image_size, 3)``. The written file is then reloaded and passed
    through ``onnx.checker.check_model``.

    Args:
        model: Keras model handed to ``tf2onnx.convert.from_keras``.
        onnx_path: Destination file for the ONNX model. May be a bare
            filename, in which case it is written to the current directory.
        image_size: Spatial dimensions placed between the batch dimension and
            the trailing dimension of 3 in the input signature.

    Returns:
        ``onnx_path``, unchanged.

    Raises:
        Any error raised by the conversion or by ``onnx.checker.check_model``
        propagates to the caller.
    """
    # Imported lazily so the rest of the module works without tf2onnx.
    import tf2onnx
    import tf2onnx.convert

    # os.path.dirname() yields "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    target_dir = os.path.dirname(onnx_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    input_signature = [
        tf.TensorSpec(
            shape=(None, *image_size, 3),
            dtype=tf.float32,
            name="input",
        )
    ]

    logger.info(f"Exporting model to ONNX β {onnx_path}")
    tf2onnx.convert.from_keras(
        model,
        input_signature=input_signature,
        opset=13,
        output_path=onnx_path,
    )

    # Reload from disk so the check covers what was actually written.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
    logger.info(f"ONNX export successful β size: {size_mb:.2f} MB")
    return onnx_path
|
|
|
|
| |
| |
| |
|
|
def dynamic_quantize(onnx_path: str, output_path: str):
    """Write a dynamically quantized (QInt8 weights) copy of an ONNX model.

    Args:
        onnx_path: Path of the source ONNX model.
        output_path: Path the quantized model is written to.

    Returns:
        ``output_path`` when quantization ran, or ``None`` when the
        quantization imports are unavailable (a warning is logged instead).
    """
    if QUANTIZATION_AVAILABLE:
        logger.info(f"Applying Dynamic Quantization β {output_path}")
        quantize_dynamic(
            model_input=onnx_path,
            model_output=output_path,
            weight_type=QuantType.QInt8,
        )

        bytes_per_mb = 1024 * 1024
        size_mb = os.path.getsize(output_path) / bytes_per_mb
        logger.info(f"Dynamic quantized model β size: {size_mb:.2f} MB")
        return output_path

    logger.warning(f"Dynamic quantization skipped: {QUANT_IMPORT_ERROR}")
    return None
|
|
|
|
| |
| |
| |
|
|
class MRICalibrationReader(CalibrationDataReader):
    """Feeds calibration batches to the static quantizer.

    Batches are pulled from the generator once, at construction time, kept in
    memory as float32 arrays, and handed out one per ``get_next`` call.
    """

    def __init__(self, data_generator, n_batches: int = 10, input_name: str = "input"):
        """Load up to ``n_batches`` batches from ``data_generator``.

        Args:
            data_generator: Iterable yielding ``(batch_x, batch_y)`` pairs;
                only ``batch_x`` is kept.
            n_batches: Maximum number of batches to pull. Zero or a negative
                value loads nothing.
            input_name: Key under which each batch is fed to the model.
                Defaults to ``"input"``.
        """
        self.input_name = input_name
        self.data = []
        self.index = 0

        # zip() consults range() first, so the generator is advanced exactly
        # n_batches times. An enumerate()/break loop would fetch one extra
        # batch and throw it away.
        for _, (batch_x, _labels) in zip(range(n_batches), data_generator):
            self.data.append(np.array(batch_x, dtype=np.float32))

        logger.info(f"Calibration reader: {len(self.data)} batches loaded")

    def get_next(self):
        """Return the next feed dict, or ``None`` once all batches are used."""
        if self.index >= len(self.data):
            return None

        batch = {self.input_name: self.data[self.index]}
        self.index += 1
        return batch
|
|
|
|
def static_quantize(onnx_path: str, output_path: str, train_data, n_batches: int = 10):
    """Write a statically quantized copy of an ONNX model.

    Calibration batches are drawn from ``train_data`` through
    ``MRICalibrationReader``; the model is quantized in QDQ format with QInt8
    activations and weights.

    Args:
        onnx_path: Path of the source ONNX model.
        output_path: Path the quantized model is written to.
        train_data: Batch generator passed to ``MRICalibrationReader``.
        n_batches: Number of calibration batches requested from the reader.

    Returns:
        ``output_path`` when quantization ran, or ``None`` when the
        quantization imports are unavailable (a warning is logged instead).
    """
    if QUANTIZATION_AVAILABLE:
        logger.info(f"Applying Static Quantization β {output_path}")
        calibration_reader = MRICalibrationReader(train_data, n_batches=n_batches)

        quant_settings = {
            "quant_format": QuantFormat.QDQ,
            "activation_type": QuantType.QInt8,
            "weight_type": QuantType.QInt8,
        }
        quantize_static(
            model_input=onnx_path,
            model_output=output_path,
            calibration_data_reader=calibration_reader,
            **quant_settings,
        )

        bytes_per_mb = 1024 * 1024
        size_mb = os.path.getsize(output_path) / bytes_per_mb
        logger.info(f"Static quantized model β size: {size_mb:.2f} MB")
        return output_path

    logger.warning(f"Static quantization skipped: {QUANT_IMPORT_ERROR}")
    return None
|
|
|
|
| |
| |
| |
|
|
def onnx_predict(onnx_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run one CPU inference pass of an ONNX model and return its first output.

    A new ``InferenceSession`` is created on every call. The array is cast to
    float32 and fed to the model's first declared input.

    Args:
        onnx_path: Path of the ONNX model file.
        img_array: Input array for the model.

    Returns:
        The array produced for the model's first declared output.
    """
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    first_input = session.get_inputs()[0].name
    first_output = session.get_outputs()[0].name

    feeds = {first_input: img_array.astype(np.float32)}
    outputs = session.run([first_output], feeds)
    return outputs[0]
|
|
|
|
| |
| |
| |
|
|
def _collect_samples(test_data, n_samples: int):
    """Gather up to ``n_samples`` inputs and labels from ``test_data`` batches."""
    inputs, labels = [], []
    total = 0

    for batch_x, batch_y in test_data:
        inputs.append(batch_x)
        labels.append(batch_y)
        total += len(batch_x)
        # Generators may never end, so stop as soon as enough is collected.
        if total >= n_samples:
            break

    if not inputs:
        raise ValueError("benchmark_models: test_data produced no batches")

    X = np.concatenate(inputs, axis=0)[:n_samples].astype(np.float32)
    y = np.concatenate(labels, axis=0)[:n_samples]
    if len(X) == 0:
        raise ValueError("benchmark_models: no samples selected (check n_samples)")
    return X, y


def benchmark_models(tf_model, paths: dict, test_data, n_samples: int = 200) -> dict:
    """Compare accuracy, per-sample latency and file size across model formats.

    The TensorFlow model and every ONNX file in ``paths`` are run on the same
    samples. Each timed run is preceded by a one-sample warm-up, and the ONNX
    session is created before the timer starts, so the reported latency
    covers inference only, not model loading or other one-time setup.

    Args:
        tf_model: Model exposing ``predict(X, verbose=0)``.
        paths: Mapping of display name to ONNX file path. Entries whose path
            is ``None`` or missing on disk are skipped with a warning.
        test_data: Iterable yielding ``(batch_x, batch_y)`` pairs.
        n_samples: Maximum number of samples to benchmark on.

    Returns:
        Dict keyed by format name; each value holds ``latency_ms``
        (milliseconds per sample), ``accuracy`` and ``size_mb`` (``None`` for
        the TensorFlow entry).

    Raises:
        ValueError: If no samples could be collected from ``test_data``.
    """
    logger.info("\nBenchmarking model formats ...")

    X, y = _collect_samples(test_data, n_samples)
    # argmax over axis 1 requires 2-D labels; assumes one-hot / per-class
    # scores - TODO confirm against the data generator.
    y_true = np.argmax(y, axis=1)

    results = {}

    # TensorFlow baseline.
    tf_model.predict(X[:1], verbose=0)  # warm-up, kept out of the measurement
    t0 = time.perf_counter()
    preds = tf_model.predict(X, verbose=0)
    tf_ms = (time.perf_counter() - t0) * 1000 / len(X)
    tf_acc = (np.argmax(preds, axis=1) == y_true).mean()

    results["TensorFlow (FP32)"] = {
        "latency_ms": float(tf_ms),
        "accuracy": float(tf_acc),
        "size_mb": None,
    }
    logger.info(f"TensorFlow (FP32) | acc={tf_acc:.4f} | {tf_ms:.2f} ms/sample")

    # ONNX variants.
    for name, path in paths.items():
        if path is None:
            logger.warning(f"Skipping {name} β path is None")
            continue
        if not os.path.exists(path):
            logger.warning(f"Skipping {name} β file not found: {path}")
            continue

        # Best-effort: one failing variant must not abort the whole benchmark.
        try:
            sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            input_name = sess.get_inputs()[0].name
            output_name = sess.get_outputs()[0].name

            sess.run([output_name], {input_name: X[:1]})  # warm-up
            t0 = time.perf_counter()
            preds = sess.run([output_name], {input_name: X})[0]
            ms = (time.perf_counter() - t0) * 1000 / len(X)

            acc = (np.argmax(preds, axis=1) == y_true).mean()
            size = os.path.getsize(path) / (1024 * 1024)

            results[name] = {
                "latency_ms": float(ms),
                "accuracy": float(acc),
                "size_mb": float(size),
            }
            logger.info(f"{name:<24} | acc={acc:.4f} | {ms:.2f} ms/sample | {size:.2f} MB")

        except Exception as e:
            logger.warning(f"Skipping {name} due to runtime error: {e}")

    return results
|
|
|
|
def print_benchmark_table(results: dict):
    """Print benchmark results to stdout as a fixed-width table.

    Args:
        results: Mapping of format name to a dict with ``accuracy``,
            ``latency_ms`` and ``size_mb`` entries. A ``size_mb`` of ``None``
            is shown as a placeholder character instead of a number.
    """
    rule = "=" * 70
    header = f"{'Format':<26} {'Accuracy':>10} {'Latency(ms)':>13} {'Size(MB)':>12}"

    print("\n" + rule)
    print(header)
    print(rule)

    for name, row in results.items():
        size_mb = row["size_mb"]
        if size_mb is None:
            size_cell = "β"
        else:
            size_cell = f"{size_mb:.2f}"
        print(f"{name:<26} {row['accuracy']:>10.4f} {row['latency_ms']:>13.2f} {size_cell:>12}")

    print(rule)
|
|
|
|
| |
| |
| |
|
|
def main(config_path: str = "config.yaml"):
    """Export the fine-tuned model to ONNX, quantize it and benchmark it.

    Steps: load the config, build the data generators, load ``ft_best.h5``
    from the configured save directory, export an FP32 ONNX model, attempt
    dynamic and static INT8 quantization, benchmark every available format
    on the test data, print the table and save the results as JSON in the
    ONNX directory.

    Args:
        config_path: Path of the config file passed to ``load_config``. The
            config must provide ``data.image_size``, ``models.onnx_dir`` and
            ``models.save_dir``.
    """
    cfg = load_config(config_path)
    image_size = tuple(cfg["data"]["image_size"])
    onnx_dir = cfg["models"]["onnx_dir"]
    save_dir = cfg["models"]["save_dir"]

    os.makedirs(onnx_dir, exist_ok=True)

    # The validation generator is returned but not needed here.
    train_data, _val_data, test_data = get_data_generators(cfg)

    # Load the trained Keras model; compile=False because it is only used
    # for export and inference.
    model_path = os.path.join(save_dir, "ft_best.h5")
    logger.info(f"Loading model from {model_path}")
    model = tf.keras.models.load_model(model_path, compile=False)

    # Output locations for each model variant.
    onnx_fp32_path = os.path.join(onnx_dir, "model_fp32.onnx")
    onnx_dynamic_path = os.path.join(onnx_dir, "model_dynamic_int8.onnx")
    onnx_static_path = os.path.join(onnx_dir, "model_static_int8.onnx")

    export_to_onnx(model, onnx_fp32_path, image_size)

    # Either call returns None when quantization is unavailable; the
    # benchmark skips None paths.
    dynamic_path = dynamic_quantize(onnx_fp32_path, onnx_dynamic_path)
    static_path = static_quantize(onnx_fp32_path, onnx_static_path, train_data, n_batches=50)

    paths = {
        "ONNX FP32": onnx_fp32_path,
        "ONNX Dynamic INT8": dynamic_path,
        "ONNX Static INT8": static_path,
    }

    results = benchmark_models(model, paths, test_data)
    print_benchmark_table(results)

    bench_path = os.path.join(onnx_dir, "benchmark_results.json")
    with open(bench_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    logger.info(f"Benchmark results saved β {bench_path}")

    if QUANTIZATION_AVAILABLE:
        logger.info("ONNX export complete. Quantization attempted.")
    else:
        logger.warning(f"ONNX export complete. Quantization skipped: {QUANT_IMPORT_ERROR}")


if __name__ == "__main__":
    main()