"""PyTorch port vs onnxruntime — assert logit drift < 1e-3 (Task 7 + 8 + 9 home)."""

import argparse
import sys
from pathlib import Path

import numpy as np
import torch
import onnxruntime as ort

from needle_torch import NeedleModel, TransformerConfig

ART = Path(__file__).resolve().parent / "artifacts"

PROD_CONFIG = TransformerConfig(
    vocab_size=8192,
    d_model=512,
    num_heads=8,
    num_kv_heads=4,
    num_encoder_layers=12,
    num_decoder_layers=8,
    max_seq_len=1024,
    no_feedforward=True,
)


def _check(ok, message):
    """Raise ``AssertionError(message)`` unless *ok* is truthy.

    Used instead of a bare ``assert`` so the parity checks still run when the
    interpreter is started with ``-O`` (which strips ``assert`` statements).
    """
    if not ok:
        raise AssertionError(message)


def _abs_diff(a, b, label):
    """Return the element-wise ``|a - b|`` after verifying the arrays line up.

    Both arrays are squeezed first, so a singleton-dimension difference such
    as ``(1, V)`` vs ``(1, 1, V)`` still compares. Any other shape mismatch
    raises instead of being silently broadcast into a meaningless diff.

    Raises:
        AssertionError: if the squeezed shapes differ.
    """
    a = np.squeeze(np.asarray(a))
    b = np.squeeze(np.asarray(b))
    _check(a.shape == b.shape, f"{label} shape mismatch: {a.shape} vs {b.shape}")
    return np.abs(a - b)


def load_pt_model():
    """Build ``NeedleModel(PROD_CONFIG)`` and load ``artifacts/needle_torch.pt``.

    The model is switched out of training mode (``train(False)``). The state
    dict is loaded on CPU with ``weights_only=True`` and ``strict=True``, so
    every parameter key must match.
    """
    model = NeedleModel(PROD_CONFIG)
    model.train(False)
    state = torch.load(ART / "needle_torch.pt", map_location="cpu", weights_only=True)
    model.load_state_dict(state, strict=True)
    return model


def verify_encoder(pt_model=None):
    """Compare the PyTorch encoder against ``artifacts/encoder.onnx``.

    Feeds the same seeded random id sequence (shape ``(1, 24)``) to both
    runtimes and checks the max absolute output difference.

    Args:
        pt_model: an already-loaded model to reuse; when ``None`` one is
            loaded with ``load_pt_model()``.

    Raises:
        AssertionError: if the output shapes disagree or the max absolute
            difference is >= 1e-3.
    """
    if pt_model is None:
        pt_model = load_pt_model()
    sess = ort.InferenceSession(str(ART / "encoder.onnx"), providers=["CPUExecutionProvider"])

    rng = np.random.default_rng(0)
    # Ids drawn from [0, 8000), i.e. within vocab_size=8192.
    ids_np = rng.integers(low=0, high=8000, size=(1, 24)).astype(np.int64)

    with torch.no_grad():
        pt_out = pt_model.encoder(torch.from_numpy(ids_np)).cpu().numpy()
    ort_out = sess.run(None, {"input_ids": ids_np})[0]

    abs_diff = _abs_diff(pt_out, ort_out, "encoder")
    diff = float(np.max(abs_diff))
    mean = float(np.mean(abs_diff))
    print(f"encoder parity: max-abs-diff={diff:.6f}, mean-abs-diff={mean:.6f}")
    _check(diff < 1e-3, f"encoder parity failed: {diff} >= 1e-3")
    print("encoder parity OK")


def verify_decoder_step(pt_model=None):
    """Single decoder step at past_seq=4 — non-trivial past_kv to catch caching bugs.

    Both runtimes receive the same seeded random ``encoder_out`` and
    ``past_kv``. Both the logits and the returned present KV cache are
    compared.

    Args:
        pt_model: an already-loaded model to reuse; when ``None`` one is
            loaded with ``load_pt_model()``.

    Raises:
        AssertionError: if output shapes disagree, or if either max absolute
            difference is >= 1e-3.
    """
    if pt_model is None:
        pt_model = load_pt_model()
    dec_sess = ort.InferenceSession(str(ART / "decoder_step.onnx"), providers=["CPUExecutionProvider"])

    rng = np.random.default_rng(1)
    # Encoder output (just random — both runtimes see the same)
    encoder_out = rng.standard_normal((1, 16, PROD_CONFIG.d_model)).astype(np.float32)
    # EOS-prefix; NOTE(review): assumes the EOS token id is 1 — confirm against
    # tokenizer.eos_token_id (verify_end_to_end reads it from the tokenizer).
    dec_id = np.array([[1]], dtype=np.int64)
    head_dim = PROD_CONFIG.d_model // PROD_CONFIG.num_heads
    # (layers, k/v, batch, kv_heads, past_seq=4, head_dim)
    past_kv = rng.standard_normal((
        PROD_CONFIG.num_decoder_layers, 2, 1, PROD_CONFIG.num_kv_heads, 4, head_dim
    )).astype(np.float32)

    with torch.no_grad():
        pt_logits, pt_present = pt_model.decoder.step(
            torch.from_numpy(dec_id),
            torch.from_numpy(encoder_out),
            torch.from_numpy(past_kv),
        )
    pt_logits_np = pt_logits.cpu().numpy()
    pt_present_np = pt_present.cpu().numpy()

    ort_logits, ort_present = dec_sess.run(None, {
        "decoder_input_ids": dec_id,
        "encoder_out": encoder_out,
        "past_self_kv": past_kv,
    })

    diff_logits = float(np.max(_abs_diff(pt_logits_np, ort_logits, "decoder logits")))
    diff_present = float(np.max(_abs_diff(pt_present_np, ort_present, "decoder present_kv")))
    print(f"decoder step parity: logits max-abs-diff={diff_logits:.6f}, present_kv max-abs-diff={diff_present:.6f}")
    _check(diff_logits < 1e-3, f"decoder logits drift: {diff_logits}")
    _check(diff_present < 1e-3, f"decoder kv drift: {diff_present}")
    print("decoder step parity OK")


def verify_end_to_end(ckpt_repo="Cactus-Compute/needle", ckpt_file="needle.pkl"):
    """Native Cactus generate() vs hand-rolled (encoder + decoder-step loop) via ONNX.

    The two paths use different decode schemes (Cactus re-runs the full
    decoder each step; ours uses a step-based KV-cache loop), but with greedy
    argmax + the per-step parity established in Tasks 2D + 7 + 8, the
    produced token sequences must match.

    Args:
        ckpt_repo: Hugging Face repo holding the upstream Flax checkpoint.
        ckpt_file: checkpoint filename within ``ckpt_repo``.

    Raises:
        AssertionError: if the decoded texts differ.
    """
    # Make the vendored upstream package (external/needle) importable.
    sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "external" / "needle"))
    from huggingface_hub import hf_hub_download
    from needle.model.architecture import SimpleAttentionNetwork
    from needle.model.run import generate as cactus_generate, _build_encoder_input, load_checkpoint
    from needle.dataset.tokenizer import get_tokenizer

    max_gen_len = 64  # shared by the native and ONNX paths so they stop alike

    # ── Native Cactus generate (constrained=False, deterministic argmax) ──
    ckpt_path = hf_hub_download(repo_id=ckpt_repo, filename=ckpt_file)
    flax_params, flax_cfg = load_checkpoint(ckpt_path)
    flax_model = SimpleAttentionNetwork(flax_cfg)
    tokenizer = get_tokenizer()

    query = "set a 5 min timer"
    tools = '[{"name": "set_timer", "description": "Set a timer.", "parameters": {"time_human": {"type": "string", "description": "duration"}}}]'

    native_text = cactus_generate(
        flax_model, flax_params, tokenizer, query,
        tools=tools, max_gen_len=max_gen_len, stream=False,
        normalize=False, constrained=False,
    )
    print(f"native generate output text: {native_text!r}")

    # ── Hand-rolled ONNX KV-cache loop ──
    enc_sess = ort.InferenceSession(str(ART / "encoder.onnx"), providers=["CPUExecutionProvider"])
    dec_sess = ort.InferenceSession(str(ART / "decoder_step.onnx"), providers=["CPUExecutionProvider"])

    # Encoder input capped at the model's max_seq_len (1024).
    enc_tokens = _build_encoder_input(tokenizer, query, tools, max_enc_len=PROD_CONFIG.max_seq_len)
    enc_input = np.array([enc_tokens], dtype=np.int64)
    encoder_out = enc_sess.run(None, {"input_ids": enc_input})[0]

    head_dim = PROD_CONFIG.d_model // PROD_CONFIG.num_heads
    # Empty cache: the past-sequence axis starts at length 0.
    past_kv = np.zeros((
        PROD_CONFIG.num_decoder_layers, 2, 1, PROD_CONFIG.num_kv_heads, 0, head_dim
    ), dtype=np.float32)

    eos_id = tokenizer.eos_token_id
    next_id = eos_id  # decoder seeded with EOS per Cactus convention
    ort_generated = []
    for _ in range(max_gen_len):
        # Each step feeds one token and receives the grown cache back.
        logits, past_kv = dec_sess.run(None, {
            "decoder_input_ids": np.array([[next_id]], dtype=np.int64),
            "encoder_out": encoder_out,
            "past_self_kv": past_kv,
        })
        next_id = int(np.argmax(logits[0, 0]))
        if next_id == eos_id:
            break
        ort_generated.append(next_id)

    ort_text = tokenizer.decode(ort_generated)
    # NOTE(review): the prefix literal is "" — startswith("") is always True
    # and slicing by len("") == 0 removes nothing, so this is a no-op. A
    # special-token prefix was presumably intended here; confirm and fill in.
    strip_prefix = ""
    if ort_text.startswith(strip_prefix):
        ort_text = ort_text[len(strip_prefix):]
    print(f"ort generate output text: {ort_text!r}")

    _check(native_text == ort_text, (
        f"end-to-end output text differs!\n"
        f"  native: {native_text!r}\n"
        f"  ort:    {ort_text!r}"
    ))
    print("end-to-end parity OK — Cactus native == ONNX hand-rolled loop")


def _parse_args():
    """Parse CLI flags; unrecognized arguments are ignored (parse_known_args)."""
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt-repo", default="Cactus-Compute/needle",
                   help="HF repo for the upstream Flax checkpoint (default: Cactus-Compute/needle)")
    p.add_argument("--ckpt-file", default="needle.pkl",
                   help="Filename within the repo (default: needle.pkl)")
    args, _ = p.parse_known_args()
    return args


if __name__ == "__main__":
    # Parse flags first so --help / bad flags fail before any expensive work.
    args = _parse_args()
    # Load the PyTorch checkpoint once and share it between the two checks.
    shared_model = load_pt_model()
    verify_encoder(shared_model)
    verify_decoder_step(shared_model)
    verify_end_to_end(args.ckpt_repo, args.ckpt_file)