| """PyTorch port vs onnxruntime — assert logit drift < 1e-3 (Task 7 + 8 + 9 home).""" |
| from pathlib import Path |
| import numpy as np |
| import torch |
| import onnxruntime as ort |
|
|
| from needle_torch import NeedleModel, TransformerConfig |
|
|
# Directory next to this file that holds the exported weights and ONNX graphs.
ART = Path(__file__).resolve().parent.joinpath("artifacts")
|
|
# Model hyper-parameters used to build the PyTorch model and to size the
# KV-cache tensors fed to the decoder-step graph.
# NOTE(review): presumably these mirror the exported checkpoint — confirm.
PROD_CONFIG = TransformerConfig(
    vocab_size=8192,
    d_model=512,
    num_heads=8,
    num_kv_heads=4,  # half as many KV heads as attention heads
    num_encoder_layers=12,
    num_decoder_layers=8,
    max_seq_len=1024,
    no_feedforward=True,
)
|
|
|
|
def load_pt_model():
    """Return a ``NeedleModel`` built from ``PROD_CONFIG`` with saved weights loaded.

    The model is switched out of training mode, then its state dict is read
    from ``artifacts/needle_torch.pt`` onto the CPU (``weights_only=True``) and
    applied with ``strict=True``.
    """
    weights_path = ART / "needle_torch.pt"

    model = NeedleModel(PROD_CONFIG)
    model.train(False)

    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    return model
|
|
|
|
def verify_encoder():
    """Check that the ONNX encoder reproduces the PyTorch encoder's output.

    One fixed-seed random batch of token ids, shape ``(1, 24)``, is fed to both
    ``pt_model.encoder`` and ``artifacts/encoder.onnx`` (CPU provider), and the
    two outputs are compared element-wise.

    Raises:
        AssertionError: if the output shapes differ, or if the maximum absolute
            difference is not strictly below 1e-3 (this includes NaN).
    """
    pt_model = load_pt_model()
    sess = ort.InferenceSession(
        str(ART / "encoder.onnx"), providers=["CPUExecutionProvider"]
    )

    # Seeded generator keeps the probe input identical from run to run.
    rng = np.random.default_rng(0)
    # Ids are drawn from [0, 8000), which stays below PROD_CONFIG.vocab_size.
    ids_np = rng.integers(low=0, high=8000, size=(1, 24)).astype(np.int64)

    with torch.no_grad():
        pt_out = pt_model.encoder(torch.from_numpy(ids_np)).cpu().numpy()
    ort_out = sess.run(None, {"input_ids": ids_np})[0]

    # Compare shapes before subtracting: NumPy broadcasting would otherwise
    # let two differently shaped outputs be compared without any error.
    if pt_out.shape != ort_out.shape:
        raise AssertionError(
            f"encoder output shape mismatch: torch {pt_out.shape} vs onnx {ort_out.shape}"
        )

    abs_err = np.abs(pt_out - ort_out)
    diff = float(np.max(abs_err))
    mean = float(np.mean(abs_err))
    print(f"encoder parity: max-abs-diff={diff:.6f}, mean-abs-diff={mean:.6f}")

    # Explicit raise instead of `assert`, so the gate is not stripped by
    # `python -O`. `not (diff < tol)` also fails when diff is NaN.
    if not (diff < 1e-3):
        raise AssertionError(f"encoder parity failed: {diff} >= 1e-3")
    print("encoder parity OK")
|
|
|
|
def verify_decoder_step():
    """Single decoder step at past_seq=4 — non-trivial past_kv to catch caching bugs.

    Runs ``pt_model.decoder.step`` and ``artifacts/decoder_step.onnx`` (CPU
    provider) on the same fixed-seed random inputs, then compares both
    returned tensors: the logits and the updated KV tensor.

    Raises:
        AssertionError: if either pair of outputs differs in shape, or if its
            maximum absolute difference is not strictly below 1e-3 (this
            includes NaN).
    """
    pt_model = load_pt_model()
    dec_sess = ort.InferenceSession(
        str(ART / "decoder_step.onnx"), providers=["CPUExecutionProvider"]
    )

    # Seeded generator keeps the probe inputs identical from run to run.
    rng = np.random.default_rng(1)

    encoder_out = rng.standard_normal((1, 16, PROD_CONFIG.d_model)).astype(np.float32)
    dec_id = np.array([[1]], dtype=np.int64)
    head_dim = PROD_CONFIG.d_model // PROD_CONFIG.num_heads
    # Random (non-zero) cache with 4 entries on the past-sequence axis.
    # NOTE(review): the size-2 axis is presumably key/value — confirm.
    past_kv = rng.standard_normal((
        PROD_CONFIG.num_decoder_layers, 2, 1, PROD_CONFIG.num_kv_heads, 4, head_dim
    )).astype(np.float32)

    with torch.no_grad():
        pt_logits, pt_present = pt_model.decoder.step(
            torch.from_numpy(dec_id),
            torch.from_numpy(encoder_out),
            torch.from_numpy(past_kv),
        )
    pt_logits_np = pt_logits.cpu().numpy()
    pt_present_np = pt_present.cpu().numpy()

    ort_logits, ort_present = dec_sess.run(None, {
        "decoder_input_ids": dec_id,
        "encoder_out": encoder_out,
        "past_self_kv": past_kv,
    })

    # Compare shapes before subtracting: NumPy broadcasting would otherwise
    # let two differently shaped outputs be compared without any error.
    if pt_logits_np.shape != ort_logits.shape:
        raise AssertionError(
            f"decoder logits shape mismatch: torch {pt_logits_np.shape} vs onnx {ort_logits.shape}"
        )
    if pt_present_np.shape != ort_present.shape:
        raise AssertionError(
            f"decoder kv shape mismatch: torch {pt_present_np.shape} vs onnx {ort_present.shape}"
        )

    diff_logits = float(np.max(np.abs(pt_logits_np - ort_logits)))
    diff_present = float(np.max(np.abs(pt_present_np - ort_present)))
    print(f"decoder step parity: logits max-abs-diff={diff_logits:.6f}, present_kv max-abs-diff={diff_present:.6f}")

    # Explicit raises instead of `assert`, so the gates are not stripped by
    # `python -O`. `not (x < tol)` also fails when x is NaN.
    if not (diff_logits < 1e-3):
        raise AssertionError(f"decoder logits drift: {diff_logits}")
    if not (diff_present < 1e-3):
        raise AssertionError(f"decoder kv drift: {diff_present}")
    print("decoder step parity OK")
|
|
|
|
def verify_end_to_end(ckpt_repo="Cactus-Compute/needle", ckpt_file="needle.pkl"):
    """Native Cactus generate() vs hand-rolled (encoder + decoder-step loop) via ONNX.

    The two paths use different decode schemes (Cactus re-runs the full decoder
    each step; ours uses a step-based KV-cache loop), but with greedy argmax + the
    per-step parity established in Tasks 2D + 7 + 8, the produced token sequences
    must match.

    Args:
        ckpt_repo: Hugging Face repo id passed to ``hf_hub_download``.
        ckpt_file: Checkpoint filename within that repo.

    Raises:
        AssertionError: if the natively generated text and the text decoded
            from the ONNX greedy loop are not equal.
    """
    import sys

    # The upstream package is imported from ../external/needle relative to
    # this file's directory; put it first on the import path.
    sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "external" / "needle"))
    from huggingface_hub import hf_hub_download

    from needle.model.architecture import SimpleAttentionNetwork
    from needle.model.run import generate as cactus_generate, _build_encoder_input, load_checkpoint
    from needle.dataset.tokenizer import get_tokenizer

    # Single generation budget shared by the native call and the ONNX loop so
    # the two paths cannot drift apart.
    max_new_tokens = 64

    # --- Native (Flax) reference path ---
    ckpt_path = hf_hub_download(repo_id=ckpt_repo, filename=ckpt_file)
    flax_params, flax_cfg = load_checkpoint(ckpt_path)
    flax_model = SimpleAttentionNetwork(flax_cfg)
    tokenizer = get_tokenizer()

    query = "set a 5 min timer"
    tools = '[{"name": "set_timer", "description": "Set a timer.", "parameters": {"time_human": {"type": "string", "description": "duration"}}}]'

    native_text = cactus_generate(
        flax_model, flax_params, tokenizer, query, tools=tools,
        max_gen_len=max_new_tokens, stream=False, normalize=False, constrained=False,
    )
    print(f"native generate output text: {native_text!r}")

    # --- ONNX path: one encoder pass, then a greedy decoder-step loop ---
    enc_sess = ort.InferenceSession(str(ART / "encoder.onnx"), providers=["CPUExecutionProvider"])
    dec_sess = ort.InferenceSession(str(ART / "decoder_step.onnx"), providers=["CPUExecutionProvider"])

    enc_tokens = _build_encoder_input(tokenizer, query, tools, max_enc_len=1024)
    enc_input = np.array([enc_tokens], dtype=np.int64)
    encoder_out = enc_sess.run(None, {"input_ids": enc_input})[0]

    # Start from an empty cache: the past-sequence axis has length 0.
    head_dim = PROD_CONFIG.d_model // PROD_CONFIG.num_heads
    past_kv = np.zeros((
        PROD_CONFIG.num_decoder_layers, 2, 1, PROD_CONFIG.num_kv_heads, 0, head_dim
    ), dtype=np.float32)

    eos_id = tokenizer.eos_token_id
    # The first decoder input is the EOS id; the same id ends generation.
    next_id = eos_id
    ort_generated = []
    for _ in range(max_new_tokens):
        # Each step returns the logits and the cache to feed into the next step.
        logits, past_kv = dec_sess.run(None, {
            "decoder_input_ids": np.array([[next_id]], dtype=np.int64),
            "encoder_out": encoder_out,
            "past_self_kv": past_kv,
        })
        next_id = int(np.argmax(logits[0, 0]))
        if next_id == eos_id:
            break
        ort_generated.append(next_id)

    ort_text = tokenizer.decode(ort_generated)
    # Drop a leading "<tool_call>" marker before comparing.
    # NOTE(review): presumably the native path already omits it — confirm.
    if ort_text.startswith("<tool_call>"):
        ort_text = ort_text[len("<tool_call>"):]
    print(f"ort generate output text: {ort_text!r}")

    # Explicit raise instead of `assert`, so the gate is not stripped by
    # `python -O`.
    if native_text != ort_text:
        raise AssertionError(
            f"end-to-end output text differs!\n"
            f" native: {native_text!r}\n"
            f" ort: {ort_text!r}"
        )
    # NOTE(review): the "β" below looks like a mis-encoded dash — confirm the
    # intended character before changing this runtime string.
    print("end-to-end parity OK β Cactus native == ONNX hand-rolled loop")
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Parse the command line before any check runs, so `--help` or a malformed
    # option is reported immediately rather than after the encoder and
    # decoder parity checks have already executed.
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt-repo", default="Cactus-Compute/needle",
                   help="HF repo for the upstream Flax checkpoint (default: Cactus-Compute/needle)")
    p.add_argument("--ckpt-file", default="needle.pkl",
                   help="Filename within the repo (default: needle.pkl)")
    # parse_known_args: unrecognised extra arguments are ignored, not rejected.
    args, _ = p.parse_known_args()

    verify_encoder()
    verify_decoder_step()
    verify_end_to_end(args.ckpt_repo, args.ckpt_file)
|
|