Upload PORTING.md with huggingface_hub
Browse files- PORTING.md +122 -0
PORTING.md
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Porting another Cactus-trained model to ONNX
|
| 2 |
+
|
| 3 |
+
The scripts in this repo were built around the published [Cactus-Compute/needle](https://huggingface.co/Cactus-Compute/needle) checkpoint, but they work as-is for **any** model trained with the upstream [Cactus pipeline](https://github.com/cactus-compute/needle). If you've finetuned Needle (or trained a new Simple-Attention-Network variant) and want a browser-ready ONNX export, this is the recipe.
|
| 4 |
+
|
| 5 |
+
The `needle_torch/` package is parametric on `TransformerConfig` β it doesn't assume the production 26M dims. `convert_weights.py` reads the config straight out of your checkpoint's payload and writes it next to the `.pt` file. `export_onnx.py` reads that config back. So as long as your finetuned model uses the same architecture (Simple Attention Network: encoder-decoder, GQA, RoPE, ZCRMSNorm, optionally no FFN), the only thing that changes is the source repo/filename.
|
| 6 |
+
|
| 7 |
+
## Prerequisites
|
| 8 |
+
|
| 9 |
+
- A Cactus-format checkpoint published on HF Hub. The checkpoint must be a serialized dict with the shape `{"config": {...}, "params": <Flax pytree>}` β this is what `needle/training/{train,pretrain}.py` saves.
|
| 10 |
+
- A modern Python (β₯ 3.11) with `uv` installed.
|
| 11 |
+
- ~3 GB of disk for the full pipeline (Flax + PyTorch + ONNX runtimes).
|
| 12 |
+
|
| 13 |
+
## Step-by-step
|
| 14 |
+
|
| 15 |
+
### 1. Clone Cactus and this pipeline
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
git clone https://github.com/cactus-compute/needle.git external/needle
|
| 19 |
+
# Plus this repo's `export/` directory and `needle_torch/` package
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### 2. Set up the env
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
cd export
|
| 26 |
+
uv sync
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### 3. Convert your checkpoint to a PyTorch `state_dict`
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
uv run python convert_weights.py \
|
| 33 |
+
--ckpt-repo your-username/your-finetune \
|
| 34 |
+
--ckpt-file weights.pkl
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
This downloads the checkpoint, walks the Flax pytree, copies tensors into a `NeedleModel` (parametric on the embedded config), and saves:
|
| 38 |
+
|
| 39 |
+
- `artifacts/needle_torch.pt` β PyTorch state_dict
|
| 40 |
+
- `artifacts/needle_torch.config.json` β config dict (used by `export_onnx.py`)
|
| 41 |
+
|
| 42 |
+
### 4. (Strongly recommended) Verify Flax β PyTorch parity
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
uv run python verify_port_parity.py
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Should print `port parity OK (< 1e-3)`. If parity fails, the conversion has a bug β fix it before exporting to ONNX. Common culprits:
|
| 49 |
+
|
| 50 |
+
- ZCRMSNorm formula: must be `(1 + Ξ³) Β· x / RMS(x)` with Ξ³ init zero, NOT the standard `Ξ³ Β· x / RMS(x)`.
|
| 51 |
+
- GQA broadcast: `k.repeat_interleave(repeats, dim=heads)` *before* attention, matching Flax's `jnp.repeat(k, repeats, axis=heads)`.
|
| 52 |
+
- Q/K-norm position: applied *before* RoPE.
|
| 53 |
+
- Linear weight transposition: Flax stores `(in, out)`, PyTorch is `(out, in)`. The script handles this on copy.
|
| 54 |
+
- Tied embedding: appears under three keys in PyTorch state_dict (`embedding.weight`, `encoder.embedding.weight`, `decoder.embedding.weight`); all three must be set to the same tensor.
|
| 55 |
+
|
| 56 |
+
### 5. Export to ONNX
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
uv run python export_onnx.py
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
Produces:
|
| 63 |
+
|
| 64 |
+
- `artifacts/encoder.onnx` β encoder graph (input_ids β encoder_out)
|
| 65 |
+
- `artifacts/decoder_step.onnx` β one decoder step with KV-cache I/O (decoder_input_ids, encoder_out, past_self_kv β logits, present_self_kv)
|
| 66 |
+
|
| 67 |
+
Both files are self-contained (no external `.data` sidecar). The decoder is exported as a *single step* so the browser-side runs it in a loop with streaming output and a growing KV cache, rather than tracing a full `Loop` op into the graph.
|
| 68 |
+
|
| 69 |
+
### 6. Verify PyTorch β ONNX parity (and end-to-end)
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
uv run python verify_parity.py \
|
| 73 |
+
--ckpt-repo your-username/your-finetune \
|
| 74 |
+
--ckpt-file weights.pkl
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
Runs three checks:
|
| 78 |
+
|
| 79 |
+
1. PyTorch encoder vs ONNX encoder, max-abs-diff < 1e-3
|
| 80 |
+
2. PyTorch decoder step vs ONNX decoder step, max-abs-diff < 1e-3
|
| 81 |
+
3. **End-to-end**: Cactus's native `generate(constrained=False)` vs a hand-rolled (encoder + decoder-step) ONNX loop β must produce byte-identical token sequences.
|
| 82 |
+
|
| 83 |
+
If (1) or (2) pass but (3) fails, the bug is almost certainly in the multi-step KV-cache handling. The most common cause is re-applying RoPE to the concatenated `(past_k + new_k)` tensor instead of just `new_k` β this double-rotates cached keys on every step. The fix lives in `needle_torch/layers.py`'s `MultiHeadAttention.forward`.
|
| 84 |
+
|
| 85 |
+
### 7. Dump the SentencePiece tokenizer for browser use
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
uv run python dump_tokenizer.py
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Copies `needle.model` and special-token IDs to where the browser can fetch them, plus emits parity goldens for the TS tokenizer port.
|
| 92 |
+
|
| 93 |
+
### 8. Push to HF Hub
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
uv run python upload_to_hf.py --repo your-username/your-finetune-onnx
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
Uploads:
|
| 100 |
+
|
| 101 |
+
- `encoder.onnx`, `decoder_step.onnx`
|
| 102 |
+
- `needle.model`, `tokenizer-specials.json`
|
| 103 |
+
- A model-card README with provenance and the parity numbers you measured
|
| 104 |
+
- The pipeline scripts themselves (so downstream finetuners can repeat the recipe)
|
| 105 |
+
|
| 106 |
+
## Plug it into the browser
|
| 107 |
+
|
| 108 |
+
The browser app at [onnx-community/needle-playground](https://huggingface.co/spaces/onnx-community/needle-playground) fetches from `onnx-community/needle-onnx` by default. To swap in your finetune, edit `web/src/config.ts`:
|
| 109 |
+
|
| 110 |
+
```typescript
|
| 111 |
+
export const MODEL_BASE_URL = import.meta.env.PROD
|
| 112 |
+
? 'https://huggingface.co/your-username/your-finetune-onnx/resolve/main'
|
| 113 |
+
: '/models-dev';
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Then `npm run build` and `deploy_space.py --repo your-username/your-finetune-playground` to ship.
|
| 117 |
+
|
| 118 |
+
## What's *not* supported
|
| 119 |
+
|
| 120 |
+
- Architecture changes beyond `TransformerConfig`'s field set (e.g. swapping ZCRMSNorm for LayerNorm, adding cross-layer parameter sharing not present in Cactus, etc.) β you'd need to edit `needle_torch/layers.py` and `model.py`. Update `notes/needle-internals.md` first so the change is documented.
|
| 121 |
+
- Quantization. The export is fp32. INT8/INT4 quantization is a separate post-processing step (`onnxruntime.quantization`); plumbing not included.
|
| 122 |
+
- Speech inputs (`enable_speech=True`). The port ignores the speech encoder branch.
|