---
license: mit
library_name: jax
tags:
  - function-calling
  - tool-use
  - encoder-decoder
  - edge
  - on-device
  - jax
  - flax
datasets:
  - Cactus-Compute/tool-calls
---

Needle

A 26M-parameter encoder-decoder transformer for on-device function calling, built on a "Simple Attention Network" architecture (no feedforward layers).

Distilled from Gemini 3.1 Flash Lite. Runs at 6000 tok/s prefill and 1200 tok/s decode on Cactus.
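As a rough illustration of what those rates mean for one request, the sketch below estimates latency as prefill time plus decode time. The prompt and output lengths are hypothetical. The estimate ignores tokenization, scheduling, and other runtime overhead, and the card does not state the device these rates were measured on, so treat the result as a lower bound rather than a measurement.

```python
PREFILL_TOK_PER_S = 6000  # prefill rate quoted above (Cactus)
DECODE_TOK_PER_S = 1200   # decode rate quoted above (Cactus)

def estimate_latency_s(prompt_tokens: int, output_tokens: int) -> float:
    """Back-of-envelope latency: the prompt (query + tool schemas) is processed
    at the prefill rate, then the tool call is emitted token by token at the
    decode rate. Runtime overhead is ignored."""
    return prompt_tokens / PREFILL_TOK_PER_S + output_tokens / DECODE_TOK_PER_S

# Hypothetical request: ~200 prompt tokens, ~20 output tokens.
latency = estimate_latency_s(200, 20)
print(f"{latency * 1000:.1f} ms")  # 50.0 ms
```

Because decode is 5x slower per token than prefill, short structured outputs like a single tool call keep the decode share of the budget small.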

Model Details

| Property | Value |
|---|---|
| Parameters | 26M |
| Architecture | Encoder-decoder, pure attention (no FFN) |
| Encoder | 12 layers, GQA (8 query heads / 4 KV heads), RoPE, gated residuals |
| Decoder | 8 layers, self-attention + cross-attention, gated residuals |
| d_model | 512 |
| Vocab | 8192 (SentencePiece BPE) |
| Norm | ZCRMSNorm (zero-centered, init=0) |
| Precision | bfloat16 (INT4 QAT during training) |
| Pretraining | 200B tokens on 16x TPU v6e (27 hours) |
| Post-training | 2B tokens of function-call data (45 minutes) |
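The table's numbers can be sanity-checked with a little arithmetic. The sketch below counts only attention projection weights and the embedding table. It relies on assumptions not stated in this card: a head dimension of d_model / 8 = 64, a decoder that uses the same GQA layout (8 query heads, 4 KV heads) for both self- and cross-attention, tied input and output embeddings, and no bias terms. Gate and norm parameters are ignored. Under those assumptions the count lands at about 26.2M, consistent with the reported 26M. The last lines derive the average pretraining throughput implied by 200B tokens in 27 hours on 16 chips.

```python
D_MODEL = 512
N_HEADS, N_KV_HEADS = 8, 4
HEAD_DIM = D_MODEL // N_HEADS  # assumption: 64
VOCAB = 8192
ENC_LAYERS, DEC_LAYERS = 12, 8

def attention_params(d_model, n_heads, n_kv_heads, head_dim):
    """Weights of one GQA attention layer (no biases): Q, K, V and output projections."""
    q = d_model * n_heads * head_dim
    k = d_model * n_kv_heads * head_dim
    v = d_model * n_kv_heads * head_dim
    o = n_heads * head_dim * d_model
    return q + k + v + o

attn = attention_params(D_MODEL, N_HEADS, N_KV_HEADS, HEAD_DIM)
embedding = VOCAB * D_MODEL      # assumed tied with the output projection
encoder = ENC_LAYERS * attn      # one self-attention per encoder block
decoder = DEC_LAYERS * 2 * attn  # self-attention + cross-attention per decoder block
total = embedding + encoder + decoder
print(f"attention layer: {attn:,}")  # attention layer: 786,432
print(f"total: {total / 1e6:.1f}M")  # total: 26.2M

# Average pretraining throughput implied by the table.
tokens, hours, chips = 200e9, 27, 16
tok_per_s = tokens / (hours * 3600)
print(f"{tok_per_s:,.0f} tok/s overall, {tok_per_s / chips:,.0f} tok/s per chip")
# 2,057,613 tok/s overall, 128,601 tok/s per chip
```

Note how the decoder dominates: with two attention sublayers per block, its 8 layers cost more than the encoder's 12.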

Architecture

There are no feedforward layers: each encoder block is gated self-attention, and each decoder block is gated self-attention followed by gated cross-attention. The only nonlinearities are softmax and sigmoid.
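To make the block structure concrete, here is a minimal NumPy sketch of one encoder block and one decoder block for a single unbatched sequence. Several details are assumptions rather than taken from this card. The gate is modeled as a per-channel sigmoid of a linear projection of the normalized input. ZCRMSNorm is implemented as RMSNorm with scale (1 + gamma), so the zero initialization gives unit scale. RoPE is omitted, and all parameter names are hypothetical. The sketch shows where softmax and sigmoid appear; it does not reproduce the JAX implementation.

```python
import numpy as np

def zc_rmsnorm(x, gamma, eps=1e-6):
    # Assumed form of ZCRMSNorm: scale by (1 + gamma), so gamma = 0 at init
    # means the normalized activations start with unit scale.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * (1.0 + gamma)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gqa(xq, xkv, p, n_heads=8, n_kv_heads=4, causal=False):
    """Grouped-query attention: each KV head serves n_heads // n_kv_heads query heads."""
    tq, tk = xq.shape[0], xkv.shape[0]
    hd = p["wq"].shape[1] // n_heads
    rep = n_heads // n_kv_heads
    q = (xq @ p["wq"]).reshape(tq, n_heads, hd)
    k = np.repeat((xkv @ p["wk"]).reshape(tk, n_kv_heads, hd), rep, axis=1)
    v = np.repeat((xkv @ p["wv"]).reshape(tk, n_kv_heads, hd), rep, axis=1)
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(hd)
    if causal:  # decoder self-attention: position i only sees positions <= i
        future = np.triu(np.ones((tq, tk), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    out = np.einsum("hqk,khd->qhd", softmax(scores), v)  # nonlinearity: softmax
    return out.reshape(tq, n_heads * hd) @ p["wo"]

def gated_attn(x, p, context=None, causal=False):
    """x + sigmoid(gate) * attention(norm(x)); there is no feedforward sublayer."""
    h = zc_rmsnorm(x, p["gamma"])
    kv = h if context is None else context  # cross-attention reads the encoder output
    gate = sigmoid(h @ p["wg"])  # nonlinearity: sigmoid (gate form assumed)
    return x + gate * gqa(h, kv, p, causal=causal)

def encoder_block(x, p):
    return gated_attn(x, p["self"])

def decoder_block(y, enc_out, p):
    y = gated_attn(y, p["self"], causal=True)
    return gated_attn(y, p["cross"], context=enc_out)

def init_attn(rng, d=512, n_heads=8, n_kv_heads=4):
    """Hypothetical parameter layout for one gated attention sublayer."""
    hd = d // n_heads
    def w(*shape):
        return rng.normal(0.0, 0.02, size=shape)
    return {"wq": w(d, n_heads * hd), "wk": w(d, n_kv_heads * hd),
            "wv": w(d, n_kv_heads * hd), "wo": w(n_heads * hd, d),
            "wg": w(d, d), "gamma": np.zeros(d)}

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 512))  # 6 encoder tokens
y = rng.normal(size=(3, 512))  # 3 decoder tokens
enc_params = {"self": init_attn(rng)}
dec_params = {"self": init_attn(rng), "cross": init_attn(rng)}
enc = encoder_block(x, enc_params)
dec = decoder_block(y, enc, dec_params)
print(enc.shape, dec.shape)  # (6, 512) (3, 512)
```

Without an FFN, every parameter in a block sits in an attention projection or a gate, which is why the parameter budget in the table is spent almost entirely on attention.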

See Simple Attention Networks (docs/simple_attention_networks.md in the GitHub repository) for the full architectural breakdown.

Quickstart

```bash
git clone https://github.com/cactus-compute/needle.git
cd needle && source ./setup
needle ui
```

This opens a web UI at http://127.0.0.1:7860 where you can test the model and finetune it on your own tools. Weights are downloaded automatically.

Usage (Python)

```python
from src.model.run import load_checkpoint, generate
from src.model.architecture import EncoderDecoderTransformer
from src.dataset.dataset import get_tokenizer

params, config = load_checkpoint("checkpoints/needle.pkl")
model = EncoderDecoderTransformer(config)
tokenizer = get_tokenizer()

result = generate(
    model, params, tokenizer,
    query="What's the weather in San Francisco?",
    tools='[{"name":"get_weather","parameters":{"location":"string"}}]',
    stream=False,
)
print(result)
# [{"name":"get_weather","arguments":{"location":"San Francisco"}}]
```
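`generate` returns the model's tool calls; executing them is left to your application. The sketch below assumes the result is a JSON string shaped like the example output above (it also accepts an already-parsed list). It infers the `tools` format from that single example: a list of objects with a `name` and a `parameters` map from argument name to type name. The helper names `tools_json` and `dispatch`, the type-name mapping, and the `get_weather` stub are hypothetical and not part of the needle package.

```python
import inspect
import json

# Map Python annotations to the type names used in the tools string (assumed).
TYPE_NAMES = {str: "string", int: "integer", float: "number", bool: "boolean"}

def get_weather(location: str) -> str:
    # Hypothetical local tool; a real one would call a weather API.
    return f"Weather for {location}: sunny"

def tools_json(registry):
    """Build a compact tools string like the one passed to generate() above."""
    specs = []
    for name, fn in registry.items():
        params = {
            pname: TYPE_NAMES.get(p.annotation, "string")
            for pname, p in inspect.signature(fn).parameters.items()
        }
        specs.append({"name": name, "parameters": params})
    return json.dumps(specs, separators=(",", ":"))

def dispatch(result, registry):
    """Execute each call the model produced and return the tool outputs in order."""
    calls = json.loads(result) if isinstance(result, str) else result
    outputs = []
    for call in calls:
        fn = registry.get(call["name"])
        if fn is None:
            raise KeyError(f"model requested unknown tool {call['name']!r}")
        outputs.append(fn(**call.get("arguments", {})))
    return outputs

registry = {"get_weather": get_weather}
print(tools_json(registry))
# [{"name":"get_weather","parameters":{"location":"string"}}]
print(dispatch('[{"name":"get_weather","arguments":{"location":"San Francisco"}}]', registry))
# ['Weather for San Francisco: sunny']
```

Checking the requested name against the registry before calling anything guards against the model emitting a tool that was never offered.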

Finetuning

Finetune on your own tools via the web UI or CLI:

```bash
# Web UI (generates data via Gemini, trains, evaluates, bundles result)
needle ui

# CLI
python -m src.training.finetune data.jsonl --checkpoint checkpoints/needle.pkl
```
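This card does not document the schema that `src.training.finetune` expects in `data.jsonl`. The snippet below is a hypothetical sketch only. It writes one JSON object per line with `query`, `tools`, and `answer` fields that mirror the inputs and output of `generate()`, and the `set_timer` tool is likewise made up. Check the finetuning script, or a file produced by the web UI, for the real field names before relying on it.

```python
import json

def write_jsonl(path, examples):
    """Write one JSON object per line (JSON Lines)."""
    with open(path, "w", encoding="utf-8") as f:
        for ex in examples:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

def read_jsonl(path):
    """Read a JSON Lines file back, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Hypothetical record layout: these field names are placeholders, not the
# documented schema of src.training.finetune.
examples = [
    {
        "query": "Set a timer for 10 minutes",
        "tools": '[{"name":"set_timer","parameters":{"minutes":"integer"}}]',
        "answer": '[{"name":"set_timer","arguments":{"minutes":10}}]',
    },
]
write_jsonl("data.jsonl", examples)
print(len(read_jsonl("data.jsonl")))  # 1
```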

File Format

The checkpoint is a Python pickle containing:

```python
{
    "params": { ... },   # nested dict of numpy float16 arrays
    "config": { ... },   # TransformerConfig fields as dict
}
```

Load with:

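```python
import pickle

# Only unpickle files from a source you trust: pickle can execute arbitrary code.
with open("needle.pkl", "rb") as f:
    data = pickle.load(f)

params, config = data["params"], data["config"]
```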
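Because `params` is a plain nested dict, it can be inspected without building the model. The sketch below walks the tree to count parameters and list the stored dtypes. The helper names are hypothetical, and the code assumes every leaf is array-like, as the format above describes.

```python
import pickle

import numpy as np

def iter_leaves(tree, prefix=""):
    """Yield (path, array) pairs from a nested dict of arrays."""
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            yield from iter_leaves(value, path)
        else:
            yield path, np.asarray(value)

def summarize(params):
    """Total parameter count and the sorted list of dtypes stored in the tree."""
    leaves = list(iter_leaves(params))
    total = sum(arr.size for _, arr in leaves)
    dtypes = sorted({str(arr.dtype) for _, arr in leaves})
    return total, dtypes

def load_and_summarize(path):
    # Only unpickle files from a source you trust: pickle can execute arbitrary code.
    with open(path, "rb") as f:
        data = pickle.load(f)
    total, dtypes = summarize(data["params"])
    print(f"{total:,} parameters, dtypes={dtypes}")
    print("config keys:", sorted(data["config"]))
    return data
```

Calling `load_and_summarize("needle.pkl")` prints a count that can be compared against the 26M figure in the table. It also shows which dtypes the checkpoint actually stores.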

Training Data

Post-trained on Cactus-Compute/tool-calls, a synthesized dataset of 2M+ function-calling examples spanning 15 tool categories (timers, messaging, media, navigation, smart home, fitness, etc.).

License

MIT

Citation

```bibtex
@misc{ndubuaku2026needle,
  title={Simple Attention Networks},
  author={Henry Ndubuaku},
  year={2026},
  url={https://github.com/cactus-compute/needle/blob/main/docs/simple_attention_networks.md}
}
```