fix: Include all Dense projection layers in ONNX export (output dim 64)

#2

Problem

The current ONNX exports (model.onnx and model_int8.onnx) are missing the final Dense projection layers from PyLate's modules_list pipeline. This causes the ONNX models to emit embeddings at the intermediate hidden dimension rather than the correct final embedding dimension.

mxbai-edge-colbert-v0-32m:

  • Current ONNX output: 768 dimensions (intermediate)
  • Expected output: 64 dimensions (after all Dense projections)

Root Cause

The ONNX export was done using the standard HuggingFace Transformers export pipeline, which only picks up the transformer backbone and possibly the first linear head. It does not include the additional Dense modules stored in the separate 1_Dense/ and 2_Dense/ directories that are part of PyLate's modules_list architecture.
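For context, a PyLate / Sentence-Transformers-style checkpoint splits the pipeline across numbered module directories, roughly as below (an illustrative layout; only the 1_Dense/ and 2_Dense/ names come from this repo, the rest is the usual modules_list convention):

```
modules.json     # declares the ordered modules_list pipeline
config.json      # transformer backbone config (what the standard export sees)
1_Dense/         # first Dense projection
2_Dense/         # final Dense projection down to the 64-dim output
```

A generic Transformers export only reads the backbone config, so everything declared in modules.json beyond the transformer is silently dropped.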

Fix

The model was re-exported by wrapping the full pipeline (Transformer + all Dense layers) in a single module before ONNX export:

  • model.onnx: fp32, opset 17, output dim = 64
  • model_int8.onnx: int8 dynamic quantization with projection layers kept in fp32 (they are small and precision-sensitive)

Both files have been verified for:

  • Correct output dimensions
  • ONNX model validity (onnx.checker)
  • Numerical consistency with PyTorch (fp32)

Benchmark Results: Fixed ONNX vs Original PyLate Checkpoint

Tested on 6 NanoBEIR datasets using brute-force MaxSim scoring. The fixed ONNX fp32 model produces identical results to the original PyLate checkpoint (embedding max diff ~2-3e-07).
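The brute-force MaxSim scoring used here can be sketched in NumPy as follows (a minimal sketch with random stand-in embeddings; it assumes L2-normalized token vectors, as ColBERT-style models produce):

```python
import numpy as np


def maxsim_score(q: np.ndarray, d: np.ndarray) -> float:
    """ColBERT MaxSim: for each query token, take the similarity of its
    best-matching document token, then sum over query tokens.
    q: (num_q_tokens, dim), d: (num_d_tokens, dim), rows L2-normalized."""
    sim = q @ d.T                     # (num_q_tokens, num_d_tokens) cosines
    return float(sim.max(axis=1).sum())


def normalize(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


# Brute force: score every document against the query, no index or pruning.
rng = np.random.default_rng(0)
query = normalize(rng.standard_normal((8, 64)).astype(np.float32))
docs = [normalize(rng.standard_normal((n, 64)).astype(np.float32))
        for n in (20, 35, 50)]
scores = [maxsim_score(query, d) for d in docs]
ranking = np.argsort(scores)[::-1]    # best-scoring document first
```

Because every query-document pair is scored exhaustively, any nDCG@10 difference between backends reflects the embeddings themselves, not retrieval approximations.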

nDCG@10

Dataset    PyLate (original)   ONNX fp32   ONNX int8   fp32 diff   int8 diff
FiQA       0.5554              0.5554      0.5597      +0.0000     +0.0043
SciFact    0.7983              0.7983      0.7925      +0.0000     -0.0059
NFCorpus   0.3673              0.3673      0.3423      +0.0000     -0.0250
SCIDOCS    0.4007              0.4007      0.3739      +0.0000     -0.0268
HotpotQA   0.8740              0.8740      0.8690      +0.0000     -0.0050
NQ         0.7474              0.7474      0.6517      +0.0000     -0.0957

Key findings:

  • fp32 ONNX matches PyLate exactly on all 6 datasets (nDCG@10 diff = 0.0000)
  • int8 ONNX shows the expected small degradation from dynamic quantization (projection layers kept in fp32)
  • Embedding max diff between fp32 ONNX and PyLate: ~2-3e-07 (numerically identical)

bclavie changed pull request status to merged