---
license: mit
library_name: jax
tags:
- function-calling
- tool-use
- encoder-decoder
- edge
- on-device
- jax
- flax
datasets:
- Cactus-Compute/tool-calls
---

# Needle

A 26M-parameter encoder-decoder transformer for on-device function calling, built on a "Simple Attention Network" architecture (no feedforward layers).

Distilled from Gemini 3.1 Flash Lite. Runs at 6000 tok/s prefill and 1200 tok/s decode on [Cactus](https://github.com/cactus-compute/cactus).
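
The two throughput figures give a rough latency budget for a single call: prefill time is prompt tokens divided by the prefill rate, and decode time is output tokens divided by the decode rate. This is a back-of-envelope sketch that assumes both rates stay constant regardless of length and ignores tokenization and runtime overhead; the 300-token prompt and 24-token output are hypothetical sizes, not measurements.

```python
PREFILL_TOK_PER_S = 6000  # prefill throughput quoted above (on Cactus)
DECODE_TOK_PER_S = 1200   # decode throughput quoted above (on Cactus)


def estimate_latency_ms(prompt_tokens: int, output_tokens: int) -> float:
    """Rough latency, assuming constant throughput and zero overhead."""
    prefill_s = prompt_tokens / PREFILL_TOK_PER_S
    decode_s = output_tokens / DECODE_TOK_PER_S
    return (prefill_s + decode_s) * 1000.0


# Hypothetical request: query + tool schema of ~300 tokens, one short call of ~24 tokens.
print(f"{estimate_latency_ms(300, 24):.0f} ms")  # 70 ms
```

Because prefill is five times faster than decode, long tool lists cost comparatively little; the length of the emitted call dominates once outputs grow.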

## Model Details

| Property | Value |
|---|---|
| Parameters | 26M |
| Architecture | Encoder-decoder, pure attention (no FFN) |
| Encoder | 12 layers, GQA (8 query heads / 4 KV heads), RoPE, gated residuals |
| Decoder | 8 layers, self-attention + cross-attention, gated residuals |
| d_model | 512 |
| Vocabulary | 8192 (SentencePiece BPE) |
| Normalization | ZCRMSNorm (zero-centered, init=0) |
| Precision | bfloat16 (INT4 quantization-aware training during training) |
| Pretraining | 200B tokens on 16× TPU v6e (27 hours) |
| Post-training | 2B tokens of function-calling data (45 minutes) |
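
As a sanity check on the 26M figure, the attention projections and embedding table implied by the table above can be counted by hand. This sketch rests on assumptions the card does not state: a head dimension of 64 (512 / 8), the decoder's self- and cross-attention using the same 8/4 GQA layout as the encoder, bias-free projections, and an input embedding tied to the output layer. Gates, norms, and any other small tensors are ignored, so the result is an approximation.

```python
D_MODEL = 512
N_HEADS, N_KV_HEADS = 8, 4
HEAD_DIM = D_MODEL // N_HEADS  # 64 (assumed)
VOCAB = 8192
ENC_LAYERS, DEC_LAYERS = 12, 8


def attention_params(d_model=D_MODEL, n_heads=N_HEADS, n_kv=N_KV_HEADS, head_dim=HEAD_DIM):
    """Weights of one bias-free GQA attention module: Q, K, V and output projections."""
    q = d_model * n_heads * head_dim
    k = d_model * n_kv * head_dim  # K and V project to only 4 heads under GQA
    v = d_model * n_kv * head_dim
    o = n_heads * head_dim * d_model
    return q + k + v + o


encoder = ENC_LAYERS * attention_params()      # self-attention only
decoder = DEC_LAYERS * 2 * attention_params()  # self-attention + cross-attention
embedding = VOCAB * D_MODEL                    # assumed tied with the output layer
total = encoder + decoder + embedding

print(f"encoder {encoder:,}  decoder {decoder:,}  embedding {embedding:,}")
print(f"total ~{total / 1e6:.1f}M")  # total ~26.2M
```

Under these assumptions the count lands at about 26.2M, in line with the table. With no FFN, the embedding table becomes a sizeable share of the model (16% here).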

## Architecture

No feedforward layers. Each encoder block is gated self-attention; each decoder block is gated self-attention + gated cross-attention. The only nonlinearities are softmax and sigmoid.

See [Simple Attention Networks](https://github.com/cactus-compute/needle/blob/main/docs/simple_attention_networks.md) for the full architectural breakdown.
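
To make the block structure concrete, here is a minimal NumPy sketch of one encoder block and one decoder block. It is not the repository's Flax implementation, and the exact gate and norm formulas are not given in this card, so the forms below are assumptions: ZCRMSNorm is written as RMSNorm with scale `(1 + w)` (so `w = 0` at init is an identity scale), and the gated residual is `x + sigmoid(h @ wg) * attn(h)` with `h = norm(x)`. RoPE is omitted for brevity, and the weight names (`wq`, `wk`, `wv`, `wo`, `wg`, `norm`) are hypothetical.

```python
import numpy as np

D_MODEL, N_HEADS, N_KV = 512, 8, 4
HEAD_DIM = D_MODEL // N_HEADS


def zc_rms_norm(x, w, eps=1e-6):
    # Assumed zero-centred RMSNorm: learned scale is (1 + w), with w initialised to 0.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * (1.0 + w)


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gqa(q_in, kv_in, p, causal=False):
    """Grouped-query attention: 8 query heads share 4 K/V heads (RoPE omitted)."""
    t, s = q_in.shape[0], kv_in.shape[0]
    q = (q_in @ p["wq"]).reshape(t, N_HEADS, HEAD_DIM)
    k = (kv_in @ p["wk"]).reshape(s, N_KV, HEAD_DIM)
    v = (kv_in @ p["wv"]).reshape(s, N_KV, HEAD_DIM)
    # Repeat each K/V head so query head i reads K/V head i // 2.
    k = np.repeat(k, N_HEADS // N_KV, axis=1)
    v = np.repeat(v, N_HEADS // N_KV, axis=1)
    scores = np.einsum("thd,shd->hts", q, k) / np.sqrt(HEAD_DIM)
    if causal:
        scores = np.where(np.tril(np.ones((t, s), dtype=bool)), scores, -1e9)
    out = np.einsum("hts,shd->thd", softmax(scores), v)
    return out.reshape(t, D_MODEL) @ p["wo"]


def gated_residual(x, p, kv_in=None, causal=False):
    # Assumed gate form: x + sigmoid(h @ wg) * attn(h), with h = norm(x).
    h = zc_rms_norm(x, p["norm"])
    src = h if kv_in is None else kv_in  # self-attention vs cross-attention
    return x + sigmoid(h @ p["wg"]) * gqa(h, src, p, causal)


def encoder_block(x, p):
    return gated_residual(x, p["self"])


def decoder_block(y, enc_out, p):
    y = gated_residual(y, p["self"], causal=True)
    return gated_residual(y, p["cross"], kv_in=enc_out)


def init_attn(rng, scale=0.02):
    kv_dim = N_KV * HEAD_DIM
    shapes = {"wq": (D_MODEL, D_MODEL), "wk": (D_MODEL, kv_dim),
              "wv": (D_MODEL, kv_dim), "wo": (D_MODEL, D_MODEL), "wg": (D_MODEL, D_MODEL)}
    p = {name: rng.normal(0.0, scale, shape) for name, shape in shapes.items()}
    p["norm"] = np.zeros(D_MODEL)
    return p


rng = np.random.default_rng(0)
enc_p = {"self": init_attn(rng)}
dec_p = {"self": init_attn(rng), "cross": init_attn(rng)}
src_tokens = rng.normal(size=(10, D_MODEL))  # stand-in for embedded query + tools
tgt_tokens = rng.normal(size=(4, D_MODEL))   # stand-in for embedded decoder prefix
enc_out = encoder_block(src_tokens, enc_p)
print(decoder_block(tgt_tokens, enc_out, dec_p).shape)  # (4, 512)
```

In this sketch the sigmoid gate is the only source of per-channel nonlinearity outside the softmax, which matches the card's statement that softmax and sigmoid are the model's only nonlinearities.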

## Quickstart

```bash
git clone https://github.com/cactus-compute/needle.git
cd needle && source ./setup
needle ui
```

This opens a web UI at http://127.0.0.1:7860 where you can test the model and finetune it on your own tools. Weights are downloaded automatically.

## Usage (Python)

```python
# Run from the repository root so the `src` package is importable.
from src.model.run import load_checkpoint, generate
from src.model.architecture import EncoderDecoderTransformer
from src.dataset.dataset import get_tokenizer

params, config = load_checkpoint("checkpoints/needle.pkl")
model = EncoderDecoderTransformer(config)
tokenizer = get_tokenizer()

result = generate(
    model, params, tokenizer,
    query="What's the weather in San Francisco?",
    # Tool schemas as a JSON string: tool name plus a map of parameter -> type.
    tools='[{"name":"get_weather","parameters":{"location":"string"}}]',
    stream=False,
)
print(result)
# [{"name":"get_weather","arguments":{"location":"San Francisco"}}]
```
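
Both the `tools` argument and the printed result are compact JSON. Below is a minimal sketch of the round trip: build the `tools` string from annotated Python functions, then dispatch each returned call to a local handler. It assumes `generate(..., stream=False)` returns either that JSON string or the already-parsed list (the card shows only the printed form). `tools_json`, `dispatch`, `HANDLERS`, and the `get_weather` body are hypothetical helpers, and mapping `int`/`float`/`bool` to `"integer"`/`"number"`/`"boolean"` is an assumption; only `"string"` appears in the example above.

```python
import inspect
import json

# Assumed type names; the card's example only shows "string".
_TYPE_NAMES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def get_weather(location: str) -> str:
    # Hypothetical local implementation of the tool.
    return f"Sunny in {location}"


def tools_json(*fns):
    """Build the compact tool schema from function signatures."""
    tools = []
    for fn in fns:
        params = {
            name: _TYPE_NAMES.get(param.annotation, "string")
            for name, param in inspect.signature(fn).parameters.items()
        }
        tools.append({"name": fn.__name__, "parameters": params})
    return json.dumps(tools, separators=(",", ":"))


HANDLERS = {"get_weather": get_weather}


def dispatch(result):
    """Run every tool call the model emitted and collect the return values."""
    calls = json.loads(result) if isinstance(result, str) else result
    outputs = []
    for call in calls:
        handler = HANDLERS.get(call["name"])
        if handler is None:
            raise KeyError(f"model requested unknown tool {call['name']!r}")
        outputs.append(handler(**call["arguments"]))
    return outputs


print(tools_json(get_weather))
# [{"name":"get_weather","parameters":{"location":"string"}}]
print(dispatch('[{"name":"get_weather","arguments":{"location":"San Francisco"}}]'))
# ['Sunny in San Francisco']
```

Rejecting unknown tool names instead of guessing means an occasional malformed or hallucinated call fails loudly rather than silently doing the wrong thing.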

## Finetuning

Finetune on your own tools via the web UI or the CLI:

```bash
# Web UI (generates data via Gemini, then trains, evaluates, and bundles the result)
needle ui

# CLI
python -m src.training.finetune data.jsonl --checkpoint checkpoints/needle.pkl
```

## File Format

The checkpoint is a Python pickle containing:

```python
{
    "params": { ... },  # nested dict of numpy float16 arrays
    "config": { ... },  # TransformerConfig fields as dict
}
```

Load it with:

```python
import pickle

# Unpickling can execute arbitrary code; only load checkpoints from a trusted source.
with open("checkpoints/needle.pkl", "rb") as f:
    data = pickle.load(f)

params, config = data["params"], data["config"]
```
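
Because `params` is a plain nested dict, a short recursive walk is enough to list every array with its shape and dtype and to total the parameter count. This sketch makes no assumption about the checkpoint's actual key names; the two-entry `demo` tree is made up for illustration, and on the real checkpoint you would pass `data["params"]` instead.

```python
from collections.abc import Mapping

import numpy as np


def iter_leaves(tree, prefix=""):
    """Yield ("path/to/leaf", array) pairs from a nested dict of arrays."""
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            yield from iter_leaves(value, path)
        else:
            yield path, np.asarray(value)


def summarize(params):
    """Print one line per array and return the total number of elements."""
    total = 0
    for path, arr in iter_leaves(params):
        print(f"{path:<40} {str(arr.shape):<14} {arr.dtype}")
        total += arr.size
    return total


# Made-up tree for illustration; real key names will differ.
demo = {
    "embed": np.zeros((8192, 512), np.float16),
    "encoder": {"layer_0": {"wq": np.zeros((512, 512), np.float16)}},
}
total = summarize(demo)
print(f"{total:,} parameters")  # 4,456,448 parameters
```

Running `summarize(data["params"])` on the real file gives a total that can be compared against the 26M figure in Model Details.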

## Training Data

Post-trained on [Cactus-Compute/tool-calls](https://huggingface.co/datasets/Cactus-Compute/tool-calls), a synthesized dataset of 2M+ function-calling examples spanning 15 tool categories (timers, messaging, media, navigation, smart home, fitness, etc.).

## License

MIT

## Citation

```bibtex
@misc{ndubuaku2026needle,
  title={Simple Attention Networks},
  author={Henry Ndubuaku},
  year={2026},
  url={https://github.com/cactus-compute/needle/blob/main/docs/simple_attention_networks.md}
}
```