"""Trace the PyTorch port's encoder + decoder-step to ONNX. Config is read from `artifacts/needle_torch.config.json` (written by `convert_weights.py`), so this script works as-is for any finetuned Cactus-architecture model. Override the input paths via CLI flags if your artifacts live elsewhere. """ import argparse import json from pathlib import Path import torch from needle_torch import NeedleModel, TransformerConfig ART = Path(__file__).resolve().parent / "artifacts" ART.mkdir(exist_ok=True) def load_pt_model(state_path: Path, config_path: Path): cfg_dict = json.loads(config_path.read_text()) cfg = TransformerConfig(**cfg_dict) m = NeedleModel(cfg) m.train(False) state = torch.load(state_path, map_location="cpu", weights_only=True) m.load_state_dict(state, strict=True) return m, cfg class DecoderStepWrapper(torch.nn.Module): """Wraps Decoder.step in a Module so torch.onnx.export traces it cleanly.""" def __init__(self, decoder): super().__init__() self.decoder = decoder def forward(self, decoder_input_ids, encoder_out, past_self_kv): return self.decoder.step(decoder_input_ids, encoder_out, past_self_kv) def export_encoder_to(model, out_path: Path): encoder = model.encoder dummy_ids = torch.zeros(1, 16, dtype=torch.long) torch.onnx.export( encoder, (dummy_ids,), out_path, input_names=["input_ids"], output_names=["encoder_out"], dynamic_axes={"input_ids": {0: "batch", 1: "seq"}, "encoder_out": {0: "batch", 1: "seq"}}, opset_version=17, do_constant_folding=True, external_data=False, dynamo=False, ) sz = out_path.stat().st_size print(f"{out_path.name} written ({sz / 1e6:.1f} MB)") def export_decoder_to(model, cfg: TransformerConfig, out_path: Path): wrapper = DecoderStepWrapper(model.decoder); wrapper.train(False) head_dim = cfg.d_model // cfg.num_heads batch, enc_seq, past_seq = 1, 16, 4 dummy_dec_ids = torch.zeros(batch, 1, dtype=torch.long) dummy_enc_out = torch.zeros(batch, enc_seq, cfg.d_model, dtype=torch.float32) dummy_past_kv = torch.zeros( cfg.num_decoder_layers, 2, batch, cfg.num_kv_heads, past_seq, head_dim, dtype=torch.float32, ) torch.onnx.export( wrapper, (dummy_dec_ids, dummy_enc_out, dummy_past_kv), out_path, input_names=["decoder_input_ids", "encoder_out", "past_self_kv"], output_names=["logits", "present_self_kv"], dynamic_axes={ "decoder_input_ids": {0: "batch"}, "encoder_out": {0: "batch", 1: "enc_seq"}, "past_self_kv": {2: "batch", 4: "past_seq"}, "logits": {0: "batch"}, "present_self_kv": {2: "batch", 4: "present_seq"}, }, opset_version=17, do_constant_folding=True, external_data=False, dynamo=False, ) sz = out_path.stat().st_size print(f"{out_path.name} written ({sz / 1e6:.1f} MB)") if __name__ == "__main__": p = argparse.ArgumentParser() p.add_argument("--state", default=str(ART / "needle_torch.pt"), help="PyTorch state_dict produced by convert_weights.py") p.add_argument("--config", default=str(ART / "needle_torch.config.json"), help="Config JSON produced by convert_weights.py (same dim shape as the source ckpt)") p.add_argument("--encoder-out", default=str(ART / "encoder.onnx")) p.add_argument("--decoder-out", default=str(ART / "decoder_step.onnx")) args = p.parse_args() m, cfg = load_pt_model(Path(args.state), Path(args.config)) export_encoder_to(m, Path(args.encoder_out)) export_decoder_to(m, cfg, Path(args.decoder_out))