needle-onnx / export_onnx.py
shreyask's picture
Upload export_onnx.py with huggingface_hub
76fda9f verified
"""Trace the PyTorch port's encoder and decoder step to ONNX.

The model config is read from `artifacts/needle_torch.config.json` (written by
`convert_weights.py`), so this script works as-is for any finetuned
Cactus-architecture model. Override the input and output paths via CLI flags
if your artifacts live elsewhere.
"""
import argparse
import json
from pathlib import Path
import torch
from needle_torch import NeedleModel, TransformerConfig
# Default home for this script's inputs and outputs: an "artifacts" folder
# sitting next to this file. It is created at import time if missing.
ART = Path(__file__).resolve().parent.joinpath("artifacts")
ART.mkdir(exist_ok=True)
def load_pt_model(state_path: Path, config_path: Path):
    """Rebuild the PyTorch model from a config JSON and a saved state_dict.

    Args:
        state_path: File holding the state_dict. It is loaded onto the CPU
            with ``weights_only=True``.
        config_path: JSON file whose top-level keys are passed as keyword
            arguments to ``TransformerConfig``.

    Returns:
        A ``(model, cfg)`` tuple: the ``NeedleModel`` with training mode
        switched off and the weights loaded, and the ``TransformerConfig``
        it was constructed from.

    Raises:
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when the
            saved keys do not match the model's keys.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which differs between machines.
    cfg_dict = json.loads(config_path.read_text(encoding="utf-8"))
    cfg = TransformerConfig(**cfg_dict)

    m = NeedleModel(cfg)
    m.train(False)  # inference mode for the subsequent export

    state = torch.load(state_path, map_location="cpu", weights_only=True)
    # strict=True: any missing or unexpected key is an error, never ignored.
    m.load_state_dict(state, strict=True)
    return m, cfg
class DecoderStepWrapper(torch.nn.Module):
    """Adapter that exposes ``decoder.step`` as ``forward``.

    ``torch.onnx.export`` traces a module's ``forward``; this wrapper routes
    that call to the decoder's ``step`` method so it can be exported.
    """

    def __init__(self, decoder):
        """Store the decoder whose ``step`` method will be invoked."""
        super().__init__()
        self.decoder = decoder

    def forward(self, decoder_input_ids, encoder_out, past_self_kv):
        """Pass the three arguments, in order, to ``decoder.step``.

        Returns whatever ``decoder.step`` returns, unchanged.
        """
        step_fn = self.decoder.step
        return step_fn(decoder_input_ids, encoder_out, past_self_kv)
def export_encoder_to(model, out_path: Path):
    """Export ``model.encoder`` to an ONNX file and report its size.

    The encoder is traced with a dummy int64 input of shape ``(1, 16)``.
    Axes 0 and 1 of both the ``input_ids`` input and the ``encoder_out``
    output are declared dynamic (``batch`` and ``seq``).

    Args:
        model: Object exposing the encoder module as ``model.encoder``.
        out_path: Destination ``.onnx`` file. Its parent directory is
            created if it does not exist yet.
    """
    encoder = model.encoder
    dummy_ids = torch.zeros(1, 16, dtype=torch.long)

    # The destination may be a caller-supplied path whose directory has not
    # been created; make it up front so the export cannot fail at write time.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    torch.onnx.export(
        encoder,
        (dummy_ids,),
        out_path,
        input_names=["input_ids"],
        output_names=["encoder_out"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "encoder_out": {0: "batch", 1: "seq"},
        },
        opset_version=17,
        do_constant_folding=True,
        # Keep the weights inside the .onnx file rather than a side file.
        external_data=False,
        # Use the TorchScript-based exporter, not the dynamo one.
        dynamo=False,
    )

    sz = out_path.stat().st_size
    print(f"{out_path.name} written ({sz / 1e6:.1f} MB)")
def export_decoder_to(model, cfg: TransformerConfig, out_path: Path):
    """Export a single decoder step of ``model.decoder`` to an ONNX file.

    The decoder is wrapped in ``DecoderStepWrapper`` and traced with dummy
    inputs:

    * ``decoder_input_ids``: int64, shape ``(batch, 1)``
    * ``encoder_out``: float32, shape ``(batch, enc_seq, d_model)``
    * ``past_self_kv``: float32, shape ``(num_decoder_layers, 2, batch,
      num_kv_heads, past_seq, head_dim)``; the size-2 axis presumably holds
      key and value (confirm against ``Decoder.step``)

    Batch, encoder length, and past/present length are declared dynamic.

    Args:
        model: Object exposing the decoder module as ``model.decoder``.
        cfg: Config providing ``d_model``, ``num_heads``, ``num_kv_heads``
            and ``num_decoder_layers``.
        out_path: Destination ``.onnx`` file. Its parent directory is
            created if it does not exist yet.
    """
    wrapper = DecoderStepWrapper(model.decoder)
    wrapper.train(False)

    head_dim = cfg.d_model // cfg.num_heads
    # Trace-time sizes only; the axes named in dynamic_axes below stay
    # variable in the exported graph.
    batch, enc_seq, past_seq = 1, 16, 4

    dummy_dec_ids = torch.zeros(batch, 1, dtype=torch.long)
    dummy_enc_out = torch.zeros(batch, enc_seq, cfg.d_model, dtype=torch.float32)
    dummy_past_kv = torch.zeros(
        cfg.num_decoder_layers, 2, batch, cfg.num_kv_heads, past_seq, head_dim,
        dtype=torch.float32,
    )

    # The destination may be a caller-supplied path whose directory has not
    # been created; make it up front so the export cannot fail at write time.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    torch.onnx.export(
        wrapper,
        (dummy_dec_ids, dummy_enc_out, dummy_past_kv),
        out_path,
        input_names=["decoder_input_ids", "encoder_out", "past_self_kv"],
        output_names=["logits", "present_self_kv"],
        dynamic_axes={
            "decoder_input_ids": {0: "batch"},
            "encoder_out": {0: "batch", 1: "enc_seq"},
            "past_self_kv": {2: "batch", 4: "past_seq"},
            "logits": {0: "batch"},
            "present_self_kv": {2: "batch", 4: "present_seq"},
        },
        opset_version=17,
        do_constant_folding=True,
        # Keep the weights inside the .onnx file rather than a side file.
        external_data=False,
        # Use the TorchScript-based exporter, not the dynamo one.
        dynamo=False,
    )

    sz = out_path.stat().st_size
    print(f"{out_path.name} written ({sz / 1e6:.1f} MB)")
def main():
    """Parse the CLI flags, load the model, and export both ONNX graphs."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--state",
        default=str(ART / "needle_torch.pt"),
        help="PyTorch state_dict produced by convert_weights.py",
    )
    parser.add_argument(
        "--config",
        default=str(ART / "needle_torch.config.json"),
        help="Config JSON produced by convert_weights.py (same dim shape as the source ckpt)",
    )
    parser.add_argument("--encoder-out", default=str(ART / "encoder.onnx"))
    parser.add_argument("--decoder-out", default=str(ART / "decoder_step.onnx"))
    opts = parser.parse_args()

    # Load once, then export the encoder followed by the decoder step.
    model, config = load_pt_model(Path(opts.state), Path(opts.config))
    export_encoder_to(model, Path(opts.encoder_out))
    export_decoder_to(model, config, Path(opts.decoder_out))


if __name__ == "__main__":
    main()