needle-onnx / PORTING.md

Porting another Cactus-trained model to ONNX

The scripts in this repo were built around the published Cactus-Compute/needle checkpoint, but they work as-is for any model trained with the upstream Cactus pipeline. If you've finetuned Needle (or trained a new Simple-Attention-Network variant) and want a browser-ready ONNX export, this is the recipe.

The needle_torch/ package is parametric on TransformerConfig; it doesn't assume the production 26M dims. convert_weights.py reads the config straight out of your checkpoint's payload and writes it next to the .pt file. export_onnx.py reads that config back. So as long as your finetuned model uses the same architecture (Simple Attention Network: encoder-decoder, GQA, RoPE, ZCRMSNorm, optionally no FFN), the only thing that changes is the source repo/filename.

Prerequisites

  • A Cactus-format checkpoint published on HF Hub. The checkpoint must be a serialized dict with the shape {"config": {...}, "params": <Flax pytree>}, which is what needle/training/{train,pretrain}.py saves.
  • A modern Python (≥ 3.11) with uv installed.
  • ~3 GB of disk for the full pipeline (Flax + PyTorch + ONNX runtimes).
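
Before running the pipeline, it can help to confirm that a checkpoint has the expected top-level shape. The sketch below is not part of the repo's scripts; it assumes the file is a plain pickle (suggested by the .pkl extension, not confirmed here), and the `d_model` field in the toy payload is a hypothetical config key. Only unpickle files you trust.

```python
import pickle


def check_cactus_payload(payload):
    """Return a list of problems; an empty list means the shape looks right."""
    if not isinstance(payload, dict):
        return ["payload is not a dict"]
    problems = []
    for key in ("config", "params"):
        if key not in payload:
            problems.append(f"missing key: {key!r}")
    if "config" in payload and not isinstance(payload["config"], dict):
        problems.append("'config' is not a dict")
    return problems


# Round-trip a toy payload through pickle (stand-in for a real checkpoint;
# "d_model" is a hypothetical field name).
toy = {"config": {"d_model": 8}, "params": {"encoder": {}}}
loaded = pickle.loads(pickle.dumps(toy))

ok_report = check_cactus_payload(loaded)
bad_report = check_cactus_payload({"params": {}})
print(ok_report)   # []
print(bad_report)  # ["missing key: 'config'"]
```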

Step-by-step

1. Clone Cactus and this pipeline

git clone https://github.com/cactus-compute/needle.git external/needle
# Plus this repo's `export/` directory and `needle_torch/` package

2. Set up the env

cd export
uv sync

3. Convert your checkpoint to a PyTorch state_dict

uv run python convert_weights.py \
    --ckpt-repo your-username/your-finetune \
    --ckpt-file weights.pkl

This downloads the checkpoint, walks the Flax pytree, copies tensors into a NeedleModel (parametric on the embedded config), and saves:

  • artifacts/needle_torch.pt: PyTorch state_dict
  • artifacts/needle_torch.config.json: config dict (used by export_onnx.py)
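
The config sidecar pattern (write the config next to the weights, read it back at export time) can be sketched as below. This is an illustration of the idea only, not the code in convert_weights.py or export_onnx.py; the helper names and the toy config fields are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def config_path_for(weights_path):
    # needle_torch.pt -> needle_torch.config.json (same directory)
    return Path(weights_path).with_suffix(".config.json")


def save_config(weights_path, config):
    path = config_path_for(weights_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(config, indent=2))
    return path


def load_config(weights_path):
    return json.loads(config_path_for(weights_path).read_text())


with tempfile.TemporaryDirectory() as tmp:
    weights = Path(tmp) / "artifacts" / "needle_torch.pt"
    toy_config = {"d_model": 8, "num_heads": 2}  # hypothetical fields
    sidecar_name = save_config(weights, toy_config).name
    restored = load_config(weights)

print(sidecar_name)            # needle_torch.config.json
print(restored == toy_config)  # True
```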

4. (Strongly recommended) Verify Flax ↔ PyTorch parity

uv run python verify_port_parity.py

Should print port parity OK (< 1e-3). If parity fails, the conversion has a bug; fix it before exporting to ONNX. Common culprits:

  • ZCRMSNorm formula: must be (1 + γ) · x / RMS(x) with γ initialized to zero, NOT the standard γ · x / RMS(x).
  • GQA broadcast: k.repeat_interleave(repeats, dim=heads) before attention, matching Flax's jnp.repeat(k, repeats, axis=heads).
  • Q/K-norm position: applied before RoPE.
  • Linear weight transposition: Flax stores (in, out), PyTorch is (out, in). The script handles this on copy.
  • Tied embedding: appears under three keys in PyTorch state_dict (embedding.weight, encoder.embedding.weight, decoder.embedding.weight); all three must be set to the same tensor.
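
Three of the culprits above can be illustrated on plain Python lists. This is a toy sketch, not the code in needle_torch/; the epsilon value and its placement inside the square root are assumptions, and the real implementation operates on tensors.

```python
import math


def rms(x, eps=1e-6):
    # Root mean square; eps placement is an assumption for this sketch.
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)


def zc_rmsnorm(x, gamma):
    # Zero-centered variant: (1 + gamma) * x / RMS(x)
    r = rms(x)
    return [(1.0 + g) * v / r for v, g in zip(x, gamma)]


def plain_rmsnorm(x, gamma):
    # Standard variant: gamma * x / RMS(x)
    r = rms(x)
    return [g * v / r for v, g in zip(x, gamma)]


def repeat_interleave(heads, repeats):
    # [k0, k1] -> [k0, k0, k1, k1], the ordering GQA needs
    return [h for h in heads for _ in range(repeats)]


def transpose(w):
    # Flax kernel (in, out) -> PyTorch weight (out, in)
    return [list(col) for col in zip(*w)]


x = [3.0, 4.0]
zero_gamma = [0.0, 0.0]
zc_out = zc_rmsnorm(x, zero_gamma)        # identity-scaled at init
plain_out = plain_rmsnorm(x, zero_gamma)  # collapses to zeros at init
print(plain_out)  # [0.0, 0.0]

interleaved = repeat_interleave(["k0", "k1"], 2)
tiled = ["k0", "k1"] * 2  # wrong ordering for GQA
print(interleaved)  # ['k0', 'k0', 'k1', 'k1']

torch_weight = transpose([[1, 2, 3], [4, 5, 6]])
print(torch_weight)  # [[1, 4], [2, 5], [3, 6]]
```

With γ at zero, the standard formula zeroes every activation while the zero-centered one passes the normalized input through, which is why mixing them up breaks parity immediately.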

5. Export to ONNX

uv run python export_onnx.py

Produces:

  • artifacts/encoder.onnx: encoder graph (input_ids → encoder_out)
  • artifacts/decoder_step.onnx: one decoder step with KV-cache I/O (decoder_input_ids, encoder_out, past_self_kv → logits, present_self_kv)

Both files are self-contained (no external .data sidecar). The decoder is exported as a single step so that the browser side can run it in a loop with streaming output and a growing KV cache, rather than tracing a full Loop op into the graph.
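
The single-step design implies a driver loop on the caller's side. A minimal sketch of that loop follows, assuming greedy decoding; `toy_step` is a hypothetical stand-in for a call into decoder_step.onnx, the BOS/EOS ids are made up, and the real cache is a tensor (past_self_kv in, present_self_kv out) rather than a Python list.

```python
def greedy_decode(step_fn, encoder_out, bos_id, eos_id, max_new_tokens):
    tokens = [bos_id]
    past_kv = []  # grows by one entry per step
    for _ in range(max_new_tokens):
        # Feed only the newest token; earlier context lives in the cache.
        logits, past_kv = step_fn(tokens[-1], encoder_out, past_kv)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_id)  # a streaming UI would emit the token here
        if next_id == eos_id:
            break
    return tokens, past_kv


def toy_step(token_id, encoder_out, past_kv):
    # Hypothetical stand-in for the ONNX decoder step:
    # always predicts (token_id + 1) modulo a 5-token vocabulary.
    vocab = 5
    logits = [0.0] * vocab
    logits[(token_id + 1) % vocab] = 1.0
    return logits, past_kv + [token_id]


tokens, cache = greedy_decode(toy_step, None, bos_id=0, eos_id=3, max_new_tokens=10)
print(tokens)      # [0, 1, 2, 3]
print(len(cache))  # 3
```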

6. Verify PyTorch ↔ ONNX parity (and end-to-end)

uv run python verify_parity.py \
    --ckpt-repo your-username/your-finetune \
    --ckpt-file weights.pkl

Runs three checks:

  1. PyTorch encoder vs ONNX encoder, max-abs-diff < 1e-3
  2. PyTorch decoder step vs ONNX decoder step, max-abs-diff < 1e-3
  3. End-to-end: Cactus's native generate(constrained=False) vs a hand-rolled (encoder + decoder-step) ONNX loop. The two must produce byte-identical token sequences.
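
The two kinds of comparison used by these checks can be sketched as follows. This is an illustration on flat Python lists, not the code in verify_parity.py; the helper names are hypothetical and the 1e-3 tolerance is the one quoted above.

```python
def max_abs_diff(a, b):
    # Largest element-wise absolute difference between two flat sequences.
    if len(a) != len(b):
        raise ValueError("shape mismatch")
    return max(abs(x - y) for x, y in zip(a, b))


def numeric_parity(a, b, tol=1e-3):
    # Used for checks (1) and (2): outputs may differ by float noise.
    return max_abs_diff(a, b) < tol


def token_parity(ref_tokens, onnx_tokens):
    # Used for check (3): token ids must match exactly, no tolerance.
    return list(ref_tokens) == list(onnx_tokens)


close = numeric_parity([1.0, 2.0], [1.0, 2.0005])
far = numeric_parity([1.0, 2.0], [1.0, 2.01])
same_tokens = token_parity([5, 9, 2], [5, 9, 2])
print(close, far, same_tokens)  # True False True
```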

If (1) and (2) pass but (3) fails, the bug is almost certainly in the multi-step KV-cache handling. The most common cause is re-applying RoPE to the concatenated (past_k + new_k) tensor instead of just new_k, which double-rotates cached keys on every step. The fix lives in needle_torch/layers.py's MultiHeadAttention.forward.
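
The double-rotation bug can be reproduced with a toy rotary embedding. The sketch below rotates a single 2-D pair by one fixed frequency, which is a simplification (real RoPE uses a different frequency per dimension pair), and it is not the code in needle_torch/layers.py.

```python
import math


def rope(vec, pos, theta=0.5):
    # Rotate a 2-D (x, y) pair by pos * theta radians.
    a = pos * theta
    x, y = vec
    return (x * math.cos(a) - y * math.sin(a),
            x * math.sin(a) + y * math.cos(a))


raw_keys = [(1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]

# Reference: every key rotated exactly once, at its own position.
reference = [rope(k, p) for p, k in enumerate(raw_keys)]

# Correct cache: rotate only the new key, then append it.
good_cache = []
for pos, k in enumerate(raw_keys):
    good_cache.append(rope(k, pos))

# Buggy cache: re-rotate the whole concatenation on every step,
# so already-rotated cached keys get rotated again.
bad_cache = []
for pos, k in enumerate(raw_keys):
    bad_cache = [rope(c, p) for p, c in enumerate(bad_cache + [k])]

drift = max(abs(r - b) for ref, bad in zip(reference, bad_cache)
            for r, b in zip(ref, bad))
print(good_cache == reference)  # True
print(drift > 0.1)              # True
```

A single decoder step never exercises this path, which is consistent with checks (1) and (2) passing while the end-to-end check fails.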

7. Dump the SentencePiece tokenizer for browser use

uv run python dump_tokenizer.py

Copies needle.model and special-token IDs to where the browser can fetch them, plus emits parity goldens for the TS tokenizer port.

8. Push to HF Hub

uv run python upload_to_hf.py --repo your-username/your-finetune-onnx

Uploads:

  • encoder.onnx, decoder_step.onnx
  • needle.model, tokenizer-specials.json
  • A model-card README with provenance and the parity numbers you measured
  • The pipeline scripts themselves (so downstream finetuners can repeat the recipe)

Plug it into the browser

The browser app at onnx-community/needle-playground fetches from onnx-community/needle-onnx by default. To swap in your finetune, edit web/src/config.ts:

export const MODEL_BASE_URL = import.meta.env.PROD
  ? 'https://huggingface.co/your-username/your-finetune-onnx/resolve/main'
  : '/models-dev';

Then run npm run build, followed by deploy_space.py --repo your-username/your-finetune-playground, to ship.
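
Before deploying, it may be worth listing the URLs the app would request from the new repo. The sketch below assumes the app appends each artifact filename to MODEL_BASE_URL, which is not confirmed here; the helper name is hypothetical and the file list is taken from the upload step above.

```python
def resolve_urls(repo_id, filenames, revision="main"):
    # Same URL shape as MODEL_BASE_URL in web/src/config.ts.
    base = f"https://huggingface.co/{repo_id}/resolve/{revision}"
    return [f"{base}/{name}" for name in filenames]


ARTIFACTS = [
    "encoder.onnx",
    "decoder_step.onnx",
    "needle.model",
    "tokenizer-specials.json",
]

urls = resolve_urls("your-username/your-finetune-onnx", ARTIFACTS)
for url in urls:
    print(url)
```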

What's not supported

  • Architecture changes beyond TransformerConfig's field set (e.g. swapping ZCRMSNorm for LayerNorm, adding cross-layer parameter sharing not present in Cactus, etc.). For these you'd need to edit needle_torch/layers.py and model.py. Update notes/needle-internals.md first so the change is documented.
  • Quantization. The export is fp32. INT8/INT4 quantization is a separate post-processing step (onnxruntime.quantization); plumbing not included.
  • Speech inputs (enable_speech=True). The port ignores the speech encoder branch.
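
For the quantization item, a starting point could look like the sketch below. It uses onnxruntime's dynamic quantization entry point and has not been tried against these graphs; the wrapper name and the output filename are hypothetical, and parity would need to be re-measured afterwards since the thresholds above were set for fp32.

```python
def quantize_int8(src_path, dst_path):
    # Imported lazily so this sketch can be loaded without onnxruntime installed.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    # Weight-only INT8 dynamic quantization of an exported fp32 graph.
    quantize_dynamic(src_path, dst_path, weight_type=QuantType.QInt8)
    return dst_path


# Hypothetical usage (requires onnxruntime and an exported model on disk):
# quantize_int8("artifacts/encoder.onnx", "artifacts/encoder.int8.onnx")
```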