Upload export_onnx.py with huggingface_hub
Browse files- export_onnx.py +92 -0
export_onnx.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Trace the PyTorch port's encoder + decoder-step to ONNX.
|
| 2 |
+
|
| 3 |
+
Config is read from `artifacts/needle_torch.config.json` (written by
|
| 4 |
+
`convert_weights.py`), so this script works as-is for any finetuned
|
| 5 |
+
Cactus-architecture model. Override the input paths via CLI flags if
|
| 6 |
+
your artifacts live elsewhere.
|
| 7 |
+
"""
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from needle_torch import NeedleModel, TransformerConfig
|
| 14 |
+
|
| 15 |
+
ART = Path(__file__).resolve().parent / "artifacts"
|
| 16 |
+
ART.mkdir(exist_ok=True)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_pt_model(state_path: Path, config_path: Path):
|
| 20 |
+
cfg_dict = json.loads(config_path.read_text())
|
| 21 |
+
cfg = TransformerConfig(**cfg_dict)
|
| 22 |
+
m = NeedleModel(cfg)
|
| 23 |
+
m.train(False)
|
| 24 |
+
state = torch.load(state_path, map_location="cpu", weights_only=True)
|
| 25 |
+
m.load_state_dict(state, strict=True)
|
| 26 |
+
return m, cfg
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DecoderStepWrapper(torch.nn.Module):
|
| 30 |
+
"""Wraps Decoder.step in a Module so torch.onnx.export traces it cleanly."""
|
| 31 |
+
def __init__(self, decoder):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.decoder = decoder
|
| 34 |
+
|
| 35 |
+
def forward(self, decoder_input_ids, encoder_out, past_self_kv):
|
| 36 |
+
return self.decoder.step(decoder_input_ids, encoder_out, past_self_kv)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def export_encoder_to(model, out_path: Path):
|
| 40 |
+
encoder = model.encoder
|
| 41 |
+
dummy_ids = torch.zeros(1, 16, dtype=torch.long)
|
| 42 |
+
torch.onnx.export(
|
| 43 |
+
encoder, (dummy_ids,), out_path,
|
| 44 |
+
input_names=["input_ids"], output_names=["encoder_out"],
|
| 45 |
+
dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
|
| 46 |
+
"encoder_out": {0: "batch", 1: "seq"}},
|
| 47 |
+
opset_version=17, do_constant_folding=True, external_data=False, dynamo=False,
|
| 48 |
+
)
|
| 49 |
+
sz = out_path.stat().st_size
|
| 50 |
+
print(f"{out_path.name} written ({sz / 1e6:.1f} MB)")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def export_decoder_to(model, cfg: TransformerConfig, out_path: Path):
|
| 54 |
+
wrapper = DecoderStepWrapper(model.decoder); wrapper.train(False)
|
| 55 |
+
head_dim = cfg.d_model // cfg.num_heads
|
| 56 |
+
batch, enc_seq, past_seq = 1, 16, 4
|
| 57 |
+
dummy_dec_ids = torch.zeros(batch, 1, dtype=torch.long)
|
| 58 |
+
dummy_enc_out = torch.zeros(batch, enc_seq, cfg.d_model, dtype=torch.float32)
|
| 59 |
+
dummy_past_kv = torch.zeros(
|
| 60 |
+
cfg.num_decoder_layers, 2, batch, cfg.num_kv_heads, past_seq, head_dim,
|
| 61 |
+
dtype=torch.float32,
|
| 62 |
+
)
|
| 63 |
+
torch.onnx.export(
|
| 64 |
+
wrapper, (dummy_dec_ids, dummy_enc_out, dummy_past_kv), out_path,
|
| 65 |
+
input_names=["decoder_input_ids", "encoder_out", "past_self_kv"],
|
| 66 |
+
output_names=["logits", "present_self_kv"],
|
| 67 |
+
dynamic_axes={
|
| 68 |
+
"decoder_input_ids": {0: "batch"},
|
| 69 |
+
"encoder_out": {0: "batch", 1: "enc_seq"},
|
| 70 |
+
"past_self_kv": {2: "batch", 4: "past_seq"},
|
| 71 |
+
"logits": {0: "batch"},
|
| 72 |
+
"present_self_kv": {2: "batch", 4: "present_seq"},
|
| 73 |
+
},
|
| 74 |
+
opset_version=17, do_constant_folding=True, external_data=False, dynamo=False,
|
| 75 |
+
)
|
| 76 |
+
sz = out_path.stat().st_size
|
| 77 |
+
print(f"{out_path.name} written ({sz / 1e6:.1f} MB)")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
p = argparse.ArgumentParser()
|
| 82 |
+
p.add_argument("--state", default=str(ART / "needle_torch.pt"),
|
| 83 |
+
help="PyTorch state_dict produced by convert_weights.py")
|
| 84 |
+
p.add_argument("--config", default=str(ART / "needle_torch.config.json"),
|
| 85 |
+
help="Config JSON produced by convert_weights.py (same dim shape as the source ckpt)")
|
| 86 |
+
p.add_argument("--encoder-out", default=str(ART / "encoder.onnx"))
|
| 87 |
+
p.add_argument("--decoder-out", default=str(ART / "decoder_step.onnx"))
|
| 88 |
+
args = p.parse_args()
|
| 89 |
+
|
| 90 |
+
m, cfg = load_pt_model(Path(args.state), Path(args.config))
|
| 91 |
+
export_encoder_to(m, Path(args.encoder_out))
|
| 92 |
+
export_decoder_to(m, cfg, Path(args.decoder_out))
|