needle-onnx / verify_parity.py
shreyask's picture
Upload verify_parity.py with huggingface_hub
f6077fc verified
"""PyTorch port vs onnxruntime — assert logit drift < 1e-3 (Task 7 + 8 + 9 home)."""
from pathlib import Path
import numpy as np
import torch
import onnxruntime as ort
from needle_torch import NeedleModel, TransformerConfig
# Directory next to this script that holds the exported artifacts read below:
# needle_torch.pt, encoder.onnx and decoder_step.onnx.
ART = Path(__file__).resolve().parent.joinpath("artifacts")

# Architecture of the production checkpoint. It has to agree with the saved
# PyTorch weights, which load_pt_model() loads with strict=True.
PROD_CONFIG = TransformerConfig(
    vocab_size=8192,
    d_model=512,
    num_heads=8,
    num_kv_heads=4,
    num_encoder_layers=12,
    num_decoder_layers=8,
    max_seq_len=1024,
    no_feedforward=True,
)
def load_pt_model():
    """Return a NeedleModel built from PROD_CONFIG with the exported weights.

    Weights come from ``ART / "needle_torch.pt"``, mapped to CPU and loaded
    with ``weights_only=True`` (no arbitrary unpickling) and ``strict=True``
    (every key must match). The model is returned in inference mode.
    """
    weights = torch.load(ART / "needle_torch.pt", map_location="cpu", weights_only=True)
    model = NeedleModel(PROD_CONFIG)
    model.load_state_dict(weights, strict=True)
    model.eval()  # same as model.train(False)
    return model
def max_mean_abs_diff(reference, candidate, label):
    """Return ``(max_abs_diff, mean_abs_diff)`` between two arrays as floats.

    Raises AssertionError when the shapes differ. Without this check NumPy
    would broadcast, e.g., a size-1 sequence axis against a size-N one and
    report a diff for arrays that were never comparable, hiding exactly the
    kind of shape/caching bug these parity checks exist to catch.
    Empty arrays compare as ``(0.0, 0.0)`` instead of making ``max`` raise.
    """
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    if reference.shape != candidate.shape:
        raise AssertionError(
            f"{label} shape mismatch: reference {reference.shape} vs candidate {candidate.shape}"
        )
    abs_diff = np.abs(reference - candidate)
    if abs_diff.size == 0:
        return 0.0, 0.0
    return float(abs_diff.max()), float(abs_diff.mean())


def check_below(value, tol, message):
    """Raise ``AssertionError(message)`` unless ``value < tol``.

    An explicit raise (not ``assert``) keeps the check active under
    ``python -O``. Testing ``not value < tol`` rather than ``value >= tol``
    also rejects NaN, matching the semantics of ``assert value < tol``.
    """
    if not value < tol:
        raise AssertionError(message)


def verify_encoder(tol=1e-3):
    """Compare the PyTorch encoder with ``encoder.onnx`` on one random id batch.

    Both runtimes receive the same (1, 24) int64 ids drawn with a fixed seed.
    The check requires the max absolute output difference to be below ``tol``.

    Raises:
        AssertionError: on an output shape mismatch or if the drift is >= ``tol``.
    """
    pt_model = load_pt_model()
    sess = ort.InferenceSession(str(ART / "encoder.onnx"), providers=["CPUExecutionProvider"])
    rng = np.random.default_rng(0)
    # NOTE(review): ids are drawn below 8000 although vocab_size is 8192 —
    # presumably to stay clear of reserved ids; confirm.
    ids_np = rng.integers(low=0, high=8000, size=(1, 24)).astype(np.int64)
    with torch.no_grad():
        pt_out = pt_model.encoder(torch.from_numpy(ids_np)).cpu().numpy()
    ort_out = sess.run(None, {"input_ids": ids_np})[0]
    diff, mean = max_mean_abs_diff(pt_out, ort_out, "encoder output")
    print(f"encoder parity: max-abs-diff={diff:.6f}, mean-abs-diff={mean:.6f}")
    check_below(diff, tol, f"encoder parity failed: {diff} >= {tol}")
    print("encoder parity OK")


def verify_decoder_step(tol=1e-3, past_seq=4):
    """Compare one PyTorch decoder step with ``decoder_step.onnx``.

    A non-empty past KV cache (``past_seq`` positions, default 4) is fed in so
    that caching bugs show up. Both the logits and the returned present KV
    must match within ``tol``.

    Raises:
        AssertionError: on a shape mismatch or if either drift is >= ``tol``.
    """
    pt_model = load_pt_model()
    dec_sess = ort.InferenceSession(str(ART / "decoder_step.onnx"), providers=["CPUExecutionProvider"])
    rng = np.random.default_rng(1)
    # Encoder output is just random noise; both runtimes see the identical tensor.
    encoder_out = rng.standard_normal((1, 16, PROD_CONFIG.d_model)).astype(np.float32)
    dec_id = np.array([[1]], dtype=np.int64)  # EOS-prefix
    head_dim = PROD_CONFIG.d_model // PROD_CONFIG.num_heads
    # Layout built here: (decoder layers, 2, batch=1, kv heads, past_seq, head_dim);
    # the size-2 axis presumably stacks K and V — confirm against the export.
    past_kv = rng.standard_normal((
        PROD_CONFIG.num_decoder_layers, 2, 1, PROD_CONFIG.num_kv_heads, past_seq, head_dim
    )).astype(np.float32)
    with torch.no_grad():
        pt_logits, pt_present = pt_model.decoder.step(
            torch.from_numpy(dec_id),
            torch.from_numpy(encoder_out),
            torch.from_numpy(past_kv),
        )
    pt_logits_np = pt_logits.cpu().numpy()
    pt_present_np = pt_present.cpu().numpy()
    ort_logits, ort_present = dec_sess.run(None, {
        "decoder_input_ids": dec_id,
        "encoder_out": encoder_out,
        "past_self_kv": past_kv,
    })
    diff_logits, _ = max_mean_abs_diff(pt_logits_np, ort_logits, "decoder logits")
    diff_present, _ = max_mean_abs_diff(pt_present_np, ort_present, "decoder present_kv")
    print(f"decoder step parity: logits max-abs-diff={diff_logits:.6f}, present_kv max-abs-diff={diff_present:.6f}")
    check_below(diff_logits, tol, f"decoder logits drift: {diff_logits}")
    check_below(diff_present, tol, f"decoder kv drift: {diff_present}")
    print("decoder step parity OK")
DEFAULT_E2E_QUERY = "set a 5 min timer"
DEFAULT_E2E_TOOLS = '[{"name": "set_timer", "description": "Set a timer.", "parameters": {"time_human": {"type": "string", "description": "duration"}}}]'


def onnx_greedy_decode(dec_sess, encoder_out, eos_id, max_gen_len, past_kv):
    """Greedy-decode with a step-wise ONNX decoder, threading its KV cache.

    The decoder is seeded with ``eos_id``. Each step feeds the previous token,
    ``encoder_out`` and the cache returned by the previous step, then takes
    ``argmax`` over ``logits[0, 0]``. Decoding stops at the first ``eos_id``
    or after ``max_gen_len`` steps.

    Returns:
        list[int]: generated ids, excluding the seed and the terminating EOS.
    """
    generated = []
    next_id = eos_id
    for _ in range(max_gen_len):
        logits, past_kv = dec_sess.run(None, {
            "decoder_input_ids": np.array([[next_id]], dtype=np.int64),
            "encoder_out": encoder_out,
            "past_self_kv": past_kv,
        })
        next_id = int(np.argmax(logits[0, 0]))
        if next_id == eos_id:
            break
        generated.append(next_id)
    return generated


def verify_end_to_end(ckpt_repo="Cactus-Compute/needle", ckpt_file="needle.pkl",
                      query=DEFAULT_E2E_QUERY, tools=DEFAULT_E2E_TOOLS, max_gen_len=64):
    """Native Cactus generate() vs hand-rolled (encoder + decoder-step loop) via ONNX.

    The two paths use different decode schemes (Cactus re-runs the full decoder
    each step; ours uses a step-based KV-cache loop), but with greedy argmax + the
    per-step parity established in Tasks 2D + 7 + 8, the produced token sequences
    must match.

    Args:
        ckpt_repo: Hugging Face repo holding the upstream Flax checkpoint.
        ckpt_file: checkpoint filename within ``ckpt_repo``.
        query: user query passed to both generation paths.
        tools: tool-description JSON string passed to both generation paths.
        max_gen_len: generation budget, shared by both paths so they cannot diverge.

    Raises:
        AssertionError: if the two decoded texts differ (raised explicitly so
            the check survives ``python -O``).
    """
    import sys

    # Make the vendored upstream package importable; guard against piling up
    # duplicate sys.path entries on repeated calls.
    needle_src = str(Path(__file__).resolve().parent.parent / "external" / "needle")
    if needle_src not in sys.path:
        sys.path.insert(0, needle_src)
    from huggingface_hub import hf_hub_download
    from needle.model.architecture import SimpleAttentionNetwork
    from needle.model.run import generate as cactus_generate, _build_encoder_input, load_checkpoint
    from needle.dataset.tokenizer import get_tokenizer

    # ── Native Cactus generate (constrained=False, deterministic argmax) ──
    ckpt_path = hf_hub_download(repo_id=ckpt_repo, filename=ckpt_file)
    flax_params, flax_cfg = load_checkpoint(ckpt_path)
    flax_model = SimpleAttentionNetwork(flax_cfg)
    tokenizer = get_tokenizer()
    native_text = cactus_generate(
        flax_model, flax_params, tokenizer, query, tools=tools,
        max_gen_len=max_gen_len, stream=False, normalize=False, constrained=False,
    )
    print(f"native generate output text: {native_text!r}")

    # ── Hand-rolled ONNX KV-cache loop ──
    enc_sess = ort.InferenceSession(str(ART / "encoder.onnx"), providers=["CPUExecutionProvider"])
    dec_sess = ort.InferenceSession(str(ART / "decoder_step.onnx"), providers=["CPUExecutionProvider"])
    enc_tokens = _build_encoder_input(tokenizer, query, tools, max_enc_len=1024)
    enc_input = np.array([enc_tokens], dtype=np.int64)
    encoder_out = enc_sess.run(None, {"input_ids": enc_input})[0]
    head_dim = PROD_CONFIG.d_model // PROD_CONFIG.num_heads
    # Empty cache: zero positions along the sequence axis before the first step.
    past_kv = np.zeros((
        PROD_CONFIG.num_decoder_layers, 2, 1, PROD_CONFIG.num_kv_heads, 0, head_dim
    ), dtype=np.float32)
    # Decoder is seeded with EOS per Cactus convention.
    ort_generated = onnx_greedy_decode(
        dec_sess, encoder_out, tokenizer.eos_token_id, max_gen_len, past_kv,
    )
    ort_text = tokenizer.decode(ort_generated)
    # Drop a leading "<tool_call>" marker so the comparison is like-for-like;
    # presumably the native generate() output already omits it — confirm.
    if ort_text.startswith("<tool_call>"):
        ort_text = ort_text[len("<tool_call>"):]
    print(f"ort generate output text: {ort_text!r}")
    if native_text != ort_text:
        raise AssertionError(
            f"end-to-end output text differs!\n"
            f"  native: {native_text!r}\n"
            f"  ort:    {ort_text!r}"
        )
    print("end-to-end parity OK — Cactus native == ONNX hand-rolled loop")
if __name__ == "__main__":
    import argparse
    import sys

    # Parse CLI flags before running anything, so --help and malformed values
    # fail fast instead of after the (slow) encoder/decoder parity checks.
    p = argparse.ArgumentParser(description="PyTorch vs onnxruntime parity checks")
    p.add_argument("--ckpt-repo", default="Cactus-Compute/needle",
                   help="HF repo for the upstream Flax checkpoint (default: Cactus-Compute/needle)")
    p.add_argument("--ckpt-file", default="needle.pkl",
                   help="Filename within the repo (default: needle.pkl)")
    args, unknown = p.parse_known_args()
    if unknown:
        # Extra args are still tolerated, but a typo such as --ckpt_repo would
        # otherwise silently fall back to the defaults.
        print(f"warning: ignoring unrecognized arguments: {unknown}", file=sys.stderr)

    verify_encoder()
    verify_decoder_step()
    verify_end_to_end(args.ckpt_repo, args.ckpt_file)