Porting another Cactus-trained model to ONNX
The scripts in this repo were built around the published Cactus-Compute/needle checkpoint, but they work as-is for any model trained with the upstream Cactus pipeline. If you've finetuned Needle (or trained a new Simple-Attention-Network variant) and want a browser-ready ONNX export, this is the recipe.
The needle_torch/ package is parametric on TransformerConfig; it doesn't assume the production 26M dims. convert_weights.py reads the config straight out of your checkpoint's payload and writes it next to the .pt file. export_onnx.py reads that config back. So as long as your finetuned model uses the same architecture (Simple Attention Network: encoder-decoder, GQA, RoPE, ZCRMSNorm, optionally no FFN), the only thing that changes is the source repo/filename.
Prerequisites
- A Cactus-format checkpoint published on HF Hub. The checkpoint must be a serialized dict with the shape {"config": {...}, "params": <Flax pytree>}; this is what needle/training/{train,pretrain}.py saves.
- A modern Python (≥ 3.11) with uv installed.
- ~3 GB of disk for the full pipeline (Flax + PyTorch + ONNX runtimes).
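Checking the payload shape before running the pipeline can save a failed conversion later. The sketch below assumes the checkpoint has already been deserialized into a Python object. The helper name check_payload and the d_model field are hypothetical placeholders, not part of this repo or of TransformerConfig.

```python
from collections.abc import Mapping


def check_payload(payload):
    """Return a list of problems with a Cactus-style checkpoint payload.

    Hypothetical helper: it only checks the top-level shape described above,
    {"config": {...}, "params": <Flax pytree>}, not individual config fields.
    """
    if not isinstance(payload, Mapping):
        return [f"payload is {type(payload).__name__}, expected a dict"]
    problems = []
    for key in ("config", "params"):
        if key not in payload:
            problems.append(f"missing top-level key {key!r}")
    if "config" in payload and not isinstance(payload["config"], Mapping):
        problems.append("'config' is not a dict")
    # Assumption: the Flax pytree is a nested dict (or FrozenDict) of arrays.
    if "params" in payload and not isinstance(payload["params"], Mapping):
        problems.append("'params' is not a nested dict (pytree)")
    return problems


# Toy payloads standing in for a real checkpoint; "d_model" is a placeholder.
good = {"config": {"d_model": 8}, "params": {"encoder": {}}}
bad = {"params": []}
print(check_payload(good))  # empty list: nothing to report
print(check_payload(bad))   # reports the missing config and the non-dict params
```

An empty list means the top-level shape matches; it says nothing about whether the architecture itself is compatible.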
Step-by-step
1. Clone Cactus and this pipeline
git clone https://github.com/cactus-compute/needle.git external/needle
# Plus this repo's `export/` directory and `needle_torch/` package
2. Set up the env
cd export
uv sync
3. Convert your checkpoint to a PyTorch state_dict
uv run python convert_weights.py \
--ckpt-repo your-username/your-finetune \
--ckpt-file weights.pkl
This downloads the checkpoint, walks the Flax pytree, copies tensors into a NeedleModel (parametric on the embedded config), and saves:
- artifacts/needle_torch.pt: PyTorch state_dict
- artifacts/needle_torch.config.json: config dict (used by export_onnx.py)
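To make the "walks the Flax pytree" step concrete, here is a minimal sketch of flattening a nested parameter dict into dotted state_dict-style keys. It is not the logic in convert_weights.py: it uses plain lists instead of tensors, the layer name q_proj is made up, and the kernel to weight renaming is an assumption based on Flax's convention of storing Dense weights under "kernel".

```python
def transpose(matrix):
    """Transpose a 2-D list of lists: (in, out) -> (out, in)."""
    return [list(row) for row in zip(*matrix)]


def flatten_params(tree, prefix=""):
    """Flatten a nested dict into dotted keys, PyTorch state_dict style.

    Assumptions (illustrative only): leaves are lists of numbers, Dense
    weights live under "kernel" with shape (in, out), and the matching
    PyTorch key ends in "weight" with shape (out, in).
    """
    flat = {}
    for name, value in tree.items():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(flatten_params(value, path))
        elif name == "kernel":
            flat[path[: -len("kernel")] + "weight"] = transpose(value)
        else:
            flat[path] = value
    return flat


# Toy pytree: one Dense layer mapping 2 inputs to 3 outputs.
params = {"encoder": {"q_proj": {"kernel": [[1, 2, 3], [4, 5, 6]]}}}
state = flatten_params(params)
print(state)  # {'encoder.q_proj.weight': [[1, 4], [2, 5], [3, 6]]}
```

The real mapping also has to account for whatever module names NeedleModel uses, so treat the key layout here as a shape illustration only.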
4. (Strongly recommended) Verify Flax → PyTorch parity
uv run python verify_port_parity.py
On success it should print port parity OK (< 1e-3). If parity fails, the conversion has a bug; fix it before exporting to ONNX. Common culprits:
- ZCRMSNorm formula: must be (1 + γ) · x / RMS(x) with γ init zero, NOT the standard γ · x / RMS(x).
- GQA broadcast: k.repeat_interleave(repeats, dim=heads) before attention, matching Flax's jnp.repeat(k, repeats, axis=heads).
- Q/K-norm position: applied before RoPE.
- Linear weight transposition: Flax stores (in, out), PyTorch is (out, in). The script handles this on copy.
- Tied embedding: appears under three keys in PyTorch state_dict (embedding.weight, encoder.embedding.weight, decoder.embedding.weight); all three must be set to the same tensor.
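The ZCRMSNorm item is the easiest one to get silently wrong, so here is a small pure-Python sketch of the two formulas side by side. It works on a single vector of floats rather than tensors, and the eps term is an assumed stabilizer that the formula above does not mention.

```python
import math


def rms(x, eps=1e-6):
    """Root mean square of a vector; eps is an assumed stabilizer."""
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)


def zc_rms_norm(x, gamma):
    """Zero-centered variant: (1 + gamma) * x / RMS(x)."""
    r = rms(x)
    return [(1.0 + g) * v / r for v, g in zip(x, gamma)]


def standard_rms_norm(x, gamma):
    """Standard variant: gamma * x / RMS(x)."""
    r = rms(x)
    return [g * v / r for v, g in zip(x, gamma)]


x = [3.0, -4.0]
gamma_init = [0.0, 0.0]  # zero init, as described above

print(zc_rms_norm(x, gamma_init))        # x rescaled to (almost exactly) unit RMS
print(standard_rms_norm(x, gamma_init))  # every entry is zero: the signal is gone
```

Loading a zero-centered γ into a standard RMSNorm therefore scales activations by roughly γ instead of 1 + γ, which shows up as a large parity error rather than a crash.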
5. Export to ONNX
uv run python export_onnx.py
Produces:
- artifacts/encoder.onnx: encoder graph (input_ids → encoder_out)
- artifacts/decoder_step.onnx: one decoder step with KV-cache I/O (decoder_input_ids, encoder_out, past_self_kv → logits, present_self_kv)
Both files are self-contained (no external .data sidecar). The decoder is exported as a single step, so the browser side runs it in a loop with streaming output and a growing KV cache rather than relying on a full Loop op traced into the graph.
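The loop that drives a single-step decoder looks roughly like the sketch below. It is written in Python for consistency with the pipeline scripts, not taken from the browser app. step_fn stands in for one run of decoder_step.onnx, and greedy argmax selection, the empty-list initial cache, and the flat logits list are all assumptions made to keep the example self-contained.

```python
def greedy_decode(step_fn, encoder_out, bos_id, eos_id, max_new_tokens=32):
    """Run a single-step decoder in a loop, feeding the KV cache back in.

    step_fn is assumed to map (decoder_input_ids, encoder_out, past_self_kv)
    to (logits, present_self_kv), with logits as a flat list of scores over
    the vocabulary.
    """
    tokens = []
    past_kv = []  # empty cache before the first step (representation assumed)
    next_input = bos_id
    for _ in range(max_new_tokens):
        logits, past_kv = step_fn([next_input], encoder_out, past_kv)
        next_input = max(range(len(logits)), key=logits.__getitem__)  # argmax
        if next_input == eos_id:
            break
        tokens.append(next_input)  # a UI could stream this token right away
    return tokens


def toy_step(decoder_input_ids, encoder_out, past_self_kv):
    """Fake decoder: predicts (last token + 1) and grows the cache by one."""
    vocab_size = 6
    target = (decoder_input_ids[-1] + 1) % vocab_size
    logits = [1.0 if i == target else 0.0 for i in range(vocab_size)]
    return logits, past_self_kv + [decoder_input_ids[-1]]


print(greedy_decode(toy_step, encoder_out=None, bos_id=0, eos_id=4))  # [1, 2, 3]
```

Only the newest token is fed in at each step; everything earlier reaches the decoder through the cache, which is why cache handling bugs only surface in multi-step runs.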
6. Verify PyTorch → ONNX parity (and end-to-end)
uv run python verify_parity.py \
--ckpt-repo your-username/your-finetune \
--ckpt-file weights.pkl
Runs three checks:
1. PyTorch encoder vs ONNX encoder, max-abs-diff < 1e-3
2. PyTorch decoder step vs ONNX decoder step, max-abs-diff < 1e-3
3. End-to-end: Cactus's native generate(constrained=False) vs a hand-rolled (encoder + decoder-step) ONNX loop; the two must produce byte-identical token sequences.
If (1) and (2) pass but (3) fails, the bug is almost certainly in the multi-step KV-cache handling. The most common cause is re-applying RoPE to the concatenated (past_k + new_k) tensor instead of just new_k, which rotates cached keys again on every step. The fix lives in MultiHeadAttention.forward in needle_torch/layers.py.
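The double-rotation bug is easy to reproduce in miniature. The sketch below models one RoPE feature pair as a complex number, so rotation by position is a single multiplication. THETA and the helper names are illustrative assumptions; the real code in needle_torch/layers.py operates on full tensors.

```python
import cmath

THETA = 0.1  # assumed rotation frequency for the single feature pair


def rope(vec, position):
    """Rotate one feature pair (stored as a complex number) by position * THETA."""
    return vec * cmath.exp(1j * position * THETA)


def step_correct(past_k, new_k):
    """Rotate only the new key at its own position, then append it."""
    return past_k + [rope(new_k, len(past_k))]


def step_buggy(past_k, new_k):
    """Bug: rotates the whole concatenated sequence, cached keys included."""
    return [rope(k, pos) for pos, k in enumerate(past_k + [new_k])]


def phases(keys):
    """Total rotation applied to each cached key, in units of THETA."""
    return [round(cmath.phase(k) / THETA, 6) for k in keys]


good, bad = [], []
for _ in range(3):  # three decode steps, each with the same raw key
    good = step_correct(good, 1 + 0j)
    bad = step_buggy(bad, 1 + 0j)

print(phases(good))  # [0.0, 1.0, 2.0]: each key rotated once, by its position
print(phases(bad))   # [0.0, 2.0, 2.0]: the key at position 1 was rotated twice
```

In this toy setup the two caches still agree after the first two steps and only diverge at the third, which illustrates why a comparison on a short or empty cache can look clean while a longer run does not.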
7. Dump the SentencePiece tokenizer for browser use
uv run python dump_tokenizer.py
Copies needle.model and special-token IDs to where the browser can fetch them, plus emits parity goldens for the TS tokenizer port.
8. Push to HF Hub
uv run python upload_to_hf.py --repo your-username/your-finetune-onnx
Uploads:
- encoder.onnx, decoder_step.onnx
- needle.model, tokenizer-specials.json
- A model-card README with provenance and the parity numbers you measured
- The pipeline scripts themselves (so downstream finetuners can repeat the recipe)
Plug it into the browser
The browser app at onnx-community/needle-playground fetches from onnx-community/needle-onnx by default. To swap in your finetune, edit web/src/config.ts:
export const MODEL_BASE_URL = import.meta.env.PROD
? 'https://huggingface.co/your-username/your-finetune-onnx/resolve/main'
: '/models-dev';
Then run npm run build, followed by deploy_space.py --repo your-username/your-finetune-playground to ship.
What's not supported
- Architecture changes beyond TransformerConfig's field set (e.g. swapping ZCRMSNorm for LayerNorm, adding cross-layer parameter sharing not present in Cactus, etc.): you'd need to edit needle_torch/layers.py and model.py. Update notes/needle-internals.md first so the change is documented.
- Quantization. The export is fp32. INT8/INT4 quantization is a separate post-processing step (onnxruntime.quantization); plumbing not included.
- Speech inputs (enable_speech=True). The port ignores the speech encoder branch.
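For the quantization item, a starting point might look like the sketch below. quantize_dynamic and QuantType are existing onnxruntime.quantization entry points. The choice of dynamic INT8, the .int8.onnx naming, and the helper names are assumptions, and none of this is wired into the pipeline. The import sits inside the function so the rest of the block runs without onnxruntime installed.

```python
from pathlib import Path


def quantized_path(src):
    """Hypothetical naming scheme: encoder.onnx -> encoder.int8.onnx."""
    src = Path(src)
    return src.with_name(src.stem + ".int8" + src.suffix)


def quantize_artifacts(artifacts_dir="artifacts"):
    """Dynamically quantize both exported graphs to INT8 weights.

    Untested against the Needle graphs; treat it as a sketch, not a recipe.
    """
    # Imported lazily so this module loads even without onnxruntime.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    outputs = []
    for name in ("encoder.onnx", "decoder_step.onnx"):
        src = Path(artifacts_dir) / name
        dst = quantized_path(src)
        quantize_dynamic(str(src), str(dst), weight_type=QuantType.QInt8)
        outputs.append(dst)
    return outputs


print(quantized_path("artifacts/encoder.onnx").name)  # encoder.int8.onnx
```

Quantized graphs would need their own parity run, since the fp32 thresholds used by the verification scripts may not hold after quantization.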