shreyask committed on
Commit
cd8d04f
·
verified ·
1 Parent(s): 52bbf4c

Upload PORTING.md with huggingface_hub

  1. PORTING.md +122 -0
PORTING.md ADDED
@@ -0,0 +1,122 @@
1
+ # Porting another Cactus-trained model to ONNX
2
+
3
+ The scripts in this repo were built around the published [Cactus-Compute/needle](https://huggingface.co/Cactus-Compute/needle) checkpoint, but they work as-is for **any** model trained with the upstream [Cactus pipeline](https://github.com/cactus-compute/needle). If you've finetuned Needle (or trained a new Simple-Attention-Network variant) and want a browser-ready ONNX export, this is the recipe.
4
+
5
+ The `needle_torch/` package is parametric on `TransformerConfig`; it doesn't assume the production 26M dims. `convert_weights.py` reads the config straight out of your checkpoint's payload and writes it next to the `.pt` file, and `export_onnx.py` reads that config back. So as long as your finetuned model uses the same architecture (Simple Attention Network: encoder-decoder, GQA, RoPE, ZCRMSNorm, optionally no FFN), the only thing that changes is the source repo/filename.
6
+
7
+ ## Prerequisites
8
+
9
+ - A Cactus-format checkpoint published on HF Hub. The checkpoint must be a serialized dict with the shape `{"config": {...}, "params": <Flax pytree>}`, which is what `needle/training/{train,pretrain}.py` saves.
10
+ - A modern Python (≥ 3.11) with `uv` installed.
11
+ - ~3 GB of disk for the full pipeline (Flax + PyTorch + ONNX runtimes).
12
+
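Before running the pipeline, it can help to confirm that a payload has the expected top-level shape. The following is a minimal sketch, not part of the repo's scripts: `check_payload` and `count_leaves` are hypothetical helpers, and the toy field and layer names are made up for illustration.

```python
from typing import Any, Mapping


def count_leaves(tree: Any) -> int:
    """Count non-dict leaves in a nested-dict pytree."""
    if isinstance(tree, Mapping):
        return sum(count_leaves(child) for child in tree.values())
    return 1


def check_payload(payload: Any) -> int:
    """Validate the {"config": ..., "params": ...} shape and return the leaf count."""
    if not isinstance(payload, Mapping):
        raise TypeError(f"expected a dict payload, got {type(payload).__name__}")
    missing = {"config", "params"} - set(payload)
    if missing:
        raise KeyError(f"payload is missing keys: {sorted(missing)}")
    if not isinstance(payload["config"], Mapping):
        raise TypeError("payload['config'] should be a dict of config fields")
    return count_leaves(payload["params"])


# Toy payload; the field and layer names here are illustrative only.
toy = {
    "config": {"d_model": 8},
    "params": {"encoder": {"dense": {"kernel": [[0.0]], "bias": [0.0]}}},
}
print(check_payload(toy))  # 2
```

For a real checkpoint, and assuming the `.pkl` file is a standard pickle, the payload would come from `pickle.load`. Only unpickle files from sources you trust.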
13
+ ## Step-by-step
14
+
15
+ ### 1. Clone Cactus and this pipeline
16
+
17
+ ```bash
18
+ git clone https://github.com/cactus-compute/needle.git external/needle
19
+ # Plus this repo's `export/` directory and `needle_torch/` package
20
+ ```
21
+
22
+ ### 2. Set up the env
23
+
24
+ ```bash
25
+ cd export
26
+ uv sync
27
+ ```
28
+
29
+ ### 3. Convert your checkpoint to a PyTorch `state_dict`
30
+
31
+ ```bash
32
+ uv run python convert_weights.py \
33
+ --ckpt-repo your-username/your-finetune \
34
+ --ckpt-file weights.pkl
35
+ ```
36
+
37
+ This downloads the checkpoint, walks the Flax pytree, copies tensors into a `NeedleModel` (parametric on the embedded config), and saves:
38
+
39
+ - `artifacts/needle_torch.pt`: PyTorch state_dict
40
+ - `artifacts/needle_torch.config.json`: config dict (used by `export_onnx.py`)
41
+
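The pytree walk and the weight-layout change can be sketched in plain Python. This is an illustration under assumptions, not the code in `convert_weights.py`: it assumes every `kernel` leaf is a 2-D matrix from a dense layer, uses dotted paths as a hypothetical naming scheme, and stands in nested lists for real arrays.

```python
from typing import Any, Dict, List, Mapping

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Swap rows and columns: (in, out) becomes (out, in)."""
    return [list(col) for col in zip(*m)]


def flatten(tree: Mapping[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten a nested dict into {"a.b.c": leaf} form."""
    flat: Dict[str, Any] = {}
    for name, child in tree.items():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(child, Mapping):
            flat.update(flatten(child, path))
        else:
            flat[path] = child
    return flat


def to_torch_layout(params: Mapping[str, Any]) -> Dict[str, Any]:
    """Rename dense `kernel` leaves to `weight` and transpose them."""
    out: Dict[str, Any] = {}
    for path, value in flatten(params).items():
        *parents, leaf = path.split(".")
        if leaf == "kernel":
            # Assumption: the kernel is 2-D, stored as (in, out) on the Flax side.
            out[".".join(parents + ["weight"])] = transpose(value)
        else:
            out[path] = value
    return out


flax_like = {"proj": {"kernel": [[1, 2, 3], [4, 5, 6]], "bias": [0, 0, 0]}}
print(to_torch_layout(flax_like))
# {'proj.weight': [[1, 4], [2, 5], [3, 6]], 'proj.bias': [0, 0, 0]}
```

The real mapping between Flax parameter paths and PyTorch `state_dict` keys depends on how `NeedleModel` names its modules, so treat the dotted paths above as placeholders.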
42
+ ### 4. (Strongly recommended) Verify Flax ↔ PyTorch parity
43
+
44
+ ```bash
45
+ uv run python verify_port_parity.py
46
+ ```
47
+
48
+ This should print `port parity OK (< 1e-3)`. If parity fails, the conversion has a bug; fix it before exporting to ONNX. Common culprits:
49
+
50
+ - ZCRMSNorm formula: must be `(1 + γ) · x / RMS(x)` with γ initialized to zero, NOT the standard `γ · x / RMS(x)`.
51
+ - GQA broadcast: `k.repeat_interleave(repeats, dim=heads)` *before* attention, matching Flax's `jnp.repeat(k, repeats, axis=heads)`.
52
+ - Q/K-norm position: applied *before* RoPE.
53
+ - Linear weight transposition: Flax stores `(in, out)`, PyTorch is `(out, in)`. The script handles this on copy.
54
+ - Tied embedding: appears under three keys in PyTorch state_dict (`embedding.weight`, `encoder.embedding.weight`, `decoder.embedding.weight`); all three must be set to the same tensor.
55
+
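Two of the culprits above are easy to see with small numbers. The sketch below is plain Python rather than the repo's PyTorch code; the `eps` term and its placement inside the square root are assumptions, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def _rms(x: Sequence[float], eps: float) -> float:
    # Assumption: eps is added to the mean square before the square root.
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)


def zc_rms_norm(x: Sequence[float], gamma: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Zero-centred scale: (1 + gamma) * x / RMS(x)."""
    rms = _rms(x, eps)
    return [(1.0 + g) * v / rms for v, g in zip(x, gamma)]


def standard_rms_norm(x: Sequence[float], gamma: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Standard scale: gamma * x / RMS(x)."""
    rms = _rms(x, eps)
    return [g * v / rms for v, g in zip(x, gamma)]


def repeat_interleave(heads: Sequence[str], repeats: int) -> List[str]:
    """Repeat each KV head consecutively, like jnp.repeat / torch.repeat_interleave."""
    return [h for h in heads for _ in range(repeats)]


x = [3.0, 4.0]
zero_gamma = [0.0, 0.0]
print(zc_rms_norm(x, zero_gamma))        # input rescaled to unit RMS
print(standard_rms_norm(x, zero_gamma))  # [0.0, 0.0]
print(repeat_interleave(["k0", "k1"], 2))  # ['k0', 'k0', 'k1', 'k1']
```

With a zero-initialized γ, the standard formula zeroes every activation, which is why porting the weights into a standard RMSNorm fails parity immediately. For GQA, a tiling repeat would give `k0, k1, k0, k1` instead, pairing query heads with the wrong KV heads.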
56
+ ### 5. Export to ONNX
57
+
58
+ ```bash
59
+ uv run python export_onnx.py
60
+ ```
61
+
62
+ Produces:
63
+
64
+ - `artifacts/encoder.onnx`: encoder graph (input_ids → encoder_out)
65
+ - `artifacts/decoder_step.onnx`: one decoder step with KV-cache I/O (decoder_input_ids, encoder_out, past_self_kv → logits, present_self_kv)
66
+
67
+ Both files are self-contained (no external `.data` sidecar). The decoder is exported as a *single step* so that the browser can run it in a loop, streaming output as the KV cache grows, rather than tracing a full `Loop` op into the graph.
68
+
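The single-step design implies a driver loop on the caller's side. The sketch below shows that loop's shape in Python with a stand-in step function. Greedy argmax decoding, the BOS/EOS handling, and every name here (`greedy_decode`, `fake_step`) are assumptions for illustration. A real driver would call the `decoder_step.onnx` session where `fake_step` is used.

```python
from typing import Callable, List, Optional, Tuple

Logits = List[float]
Cache = Optional[list]
StepFn = Callable[[int, object, Cache], Tuple[Logits, list]]


def greedy_decode(
    step: StepFn,
    encoder_out: object,
    bos_id: int,
    eos_id: int,
    max_new_tokens: int,
) -> List[int]:
    """Run one decoder step per token, feeding the returned cache back in."""
    tokens: List[int] = []
    cache: Cache = None  # no past keys/values before the first step
    current = bos_id
    for _ in range(max_new_tokens):
        logits, cache = step(current, encoder_out, cache)
        current = max(range(len(logits)), key=logits.__getitem__)  # argmax
        if current == eos_id:
            break
        tokens.append(current)  # a streaming caller could emit the token here
    return tokens


def fake_step(token: int, encoder_out: object, cache: Cache) -> Tuple[Logits, list]:
    """Stand-in for the ONNX decoder step: always predicts token + 1 (mod 5)."""
    new_cache = (cache or []) + [token]  # the cache grows by one entry per step
    nxt = (token + 1) % 5
    return [1.0 if i == nxt else 0.0 for i in range(5)], new_cache


print(greedy_decode(fake_step, None, bos_id=0, eos_id=3, max_new_tokens=10))  # [1, 2]
```

Only the newest token is passed to each step. Everything earlier lives in the cache, which is what keeps the per-step cost low.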
69
+ ### 6. Verify PyTorch ↔ ONNX parity (and end-to-end)
70
+
71
+ ```bash
72
+ uv run python verify_parity.py \
73
+ --ckpt-repo your-username/your-finetune \
74
+ --ckpt-file weights.pkl
75
+ ```
76
+
77
+ Runs three checks:
78
+
79
+ 1. PyTorch encoder vs ONNX encoder, max-abs-diff < 1e-3
80
+ 2. PyTorch decoder step vs ONNX decoder step, max-abs-diff < 1e-3
81
+ 3. **End-to-end**: Cactus's native `generate(constrained=False)` vs a hand-rolled (encoder + decoder-step) ONNX loop. The two must produce byte-identical token sequences.
82
+
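The two kinds of check above differ in strictness: tensors are compared within a tolerance, token sequences exactly. The following is a minimal sketch of both on flat Python sequences, with hypothetical helper names. It is not the implementation in `verify_parity.py`.

```python
from typing import Sequence


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two flat sequences."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def within_tolerance(a: Sequence[float], b: Sequence[float], tol: float = 1e-3) -> bool:
    """Checks (1) and (2): numeric outputs agree to within `tol`."""
    return max_abs_diff(a, b) < tol


def same_tokens(reference: Sequence[int], candidate: Sequence[int]) -> bool:
    """Check (3): generated token IDs must match exactly, with no tolerance."""
    return list(reference) == list(candidate)


print(within_tolerance([1.0, 2.0], [1.0004, 1.9998]))  # True
print(within_tolerance([1.0, 2.0], [1.0, 2.01]))       # False
print(same_tokens([5, 7, 9], [5, 7, 9]))               # True
```

Real outputs would be multi-dimensional arrays, flattened before comparison.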
83
+ If (1) and (2) pass but (3) fails, the bug is almost certainly in the multi-step KV-cache handling. The most common cause is re-applying RoPE to the concatenated `(past_k + new_k)` tensor instead of just `new_k`, which double-rotates cached keys on every step. The fix lives in `needle_torch/layers.py`'s `MultiHeadAttention.forward`.
84
+
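The double-rotation bug can be reproduced with a toy RoPE that rotates a single 2-D pair. This sketch is an illustration only: the single frequency `theta`, the tuple-based keys, and the function names are assumptions, not the code in `needle_torch/layers.py`.

```python
import math
from typing import List, Tuple

Vec = Tuple[float, float]


def rope(vec: Vec, pos: int, theta: float = 0.5) -> Vec:
    """Rotate one 2-D pair by pos * theta radians (a single RoPE frequency)."""
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    x, y = vec
    return (x * c - y * s, x * s + y * c)


def step_correct(cache: List[Vec], new_k: Vec, pos: int) -> List[Vec]:
    """Rotate only the new key, then append it to the already-rotated cache."""
    return cache + [rope(new_k, pos)]


def step_buggy(cache: List[Vec], new_k: Vec) -> List[Vec]:
    """Concatenate first, then rotate everything: cached keys get rotated again."""
    merged = cache + [new_k]
    return [rope(k, p) for p, k in enumerate(merged)]


raw_keys: List[Vec] = [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
reference = [rope(k, p) for p, k in enumerate(raw_keys)]  # full recompute, no cache

correct: List[Vec] = []
buggy: List[Vec] = []
for pos, k in enumerate(raw_keys):
    correct = step_correct(correct, k, pos)
    buggy = step_buggy(buggy, k)

print(correct == reference)  # True
print(buggy == reference)    # False
```

In this toy, the key at position 0 is unaffected because its rotation angle is zero, and the error first shows up at the third step, once a key at position 1 has been cached and rotated a second time. That matches the symptom of single-step parity passing while the end-to-end check fails.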
85
+ ### 7. Dump the SentencePiece tokenizer for browser use
86
+
87
+ ```bash
88
+ uv run python dump_tokenizer.py
89
+ ```
90
+
91
+ Copies `needle.model` and special-token IDs to where the browser can fetch them, plus emits parity goldens for the TS tokenizer port.
92
+
93
+ ### 8. Push to HF Hub
94
+
95
+ ```bash
96
+ uv run python upload_to_hf.py --repo your-username/your-finetune-onnx
97
+ ```
98
+
99
+ Uploads:
100
+
101
+ - `encoder.onnx`, `decoder_step.onnx`
102
+ - `needle.model`, `tokenizer-specials.json`
103
+ - A model-card README with provenance and the parity numbers you measured
104
+ - The pipeline scripts themselves (so downstream finetuners can repeat the recipe)
105
+
106
+ ## Plug it into the browser
107
+
108
+ The browser app at [onnx-community/needle-playground](https://huggingface.co/spaces/onnx-community/needle-playground) fetches from `onnx-community/needle-onnx` by default. To swap in your finetune, edit `web/src/config.ts`:
109
+
110
+ ```typescript
111
+ export const MODEL_BASE_URL = import.meta.env.PROD
112
+ ? 'https://huggingface.co/your-username/your-finetune-onnx/resolve/main'
113
+ : '/models-dev';
114
+ ```
115
+
116
+ Then run `npm run build` and `deploy_space.py --repo your-username/your-finetune-playground` to ship.
117
+
118
+ ## What's *not* supported
119
+
120
+ - Architecture changes beyond `TransformerConfig`'s field set (e.g. swapping ZCRMSNorm for LayerNorm, or adding cross-layer parameter sharing not present in Cactus). You'd need to edit `needle_torch/layers.py` and `model.py`. Update `notes/needle-internals.md` first so the change is documented.
121
+ - Quantization. The export is fp32. INT8/INT4 quantization is a separate post-processing step (`onnxruntime.quantization`); plumbing not included.
122
+ - Speech inputs (`enable_speech=True`). The port ignores the speech encoder branch.
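
For the quantization gap, one possible starting point is ONNX Runtime's dynamic quantization. This is a sketch under assumptions and has not been validated against this model: `quantize_int8` is a hypothetical wrapper, the `.int8.onnx` output names are made up, and whether INT8 weights preserve acceptable output quality is something you would need to measure.

```python
from pathlib import Path


def quantize_int8(src: str, dst: str) -> Path:
    """Write a dynamically quantized (INT8 weights) copy of an ONNX model."""
    # Deferred import so this file still loads where onnxruntime is not installed.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        model_input=src,
        model_output=dst,
        weight_type=QuantType.QInt8,
    )
    return Path(dst)
```

A call such as `quantize_int8("artifacts/encoder.onnx", "artifacts/encoder.int8.onnx")` would be repeated for `decoder_step.onnx`. The fp32 parity thresholds in step 6 would probably not hold for a quantized graph, so any check against the quantized files needs its own tolerance.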