shreyask committed on
Commit
76fda9f
·
verified ·
1 Parent(s): 03f1e75

Upload export_onnx.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. export_onnx.py +92 -0
export_onnx.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Trace the PyTorch port's encoder + decoder-step to ONNX.
2
+
3
+ Config is read from `artifacts/needle_torch.config.json` (written by
4
+ `convert_weights.py`), so this script works as-is for any finetuned
5
+ Cactus-architecture model. Override the input paths via CLI flags if
6
+ your artifacts live elsewhere.
7
+ """
8
+ import argparse
9
+ import json
10
+ from pathlib import Path
11
+ import torch
12
+
13
+ from needle_torch import NeedleModel, TransformerConfig
14
+
15
# Default location for the artifacts this script reads and writes: an
# "artifacts" directory that sits next to this file.
ART = Path(__file__).resolve().parent / "artifacts"
# Created at import time so the default output paths have an existing parent.
ART.mkdir(exist_ok=True)
17
+
18
+
19
def load_pt_model(
    state_path: Path, config_path: Path
) -> "tuple[NeedleModel, TransformerConfig]":
    """Build a ``NeedleModel`` from a config JSON and load its weights.

    Args:
        state_path: Path to a PyTorch ``state_dict`` file.
        config_path: Path to a JSON object whose keys are passed verbatim as
            keyword arguments to ``TransformerConfig``.

    Returns:
        ``(model, cfg)``: the model, switched out of training mode with the
        weights loaded on CPU, and the config it was built from.

    Raises:
        Whatever ``load_state_dict(strict=True)`` raises when the state dict
        keys do not match the model exactly.
    """
    # JSON is UTF-8 by convention; do not depend on the platform default.
    cfg_dict = json.loads(config_path.read_text(encoding="utf-8"))
    cfg = TransformerConfig(**cfg_dict)

    model = NeedleModel(cfg)
    model.eval()

    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(state_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state, strict=True)
    return model, cfg
27
+
28
+
29
class DecoderStepWrapper(torch.nn.Module):
    """Expose ``decoder.step`` as a module's ``forward``.

    ``torch.onnx.export`` traces a module through its ``forward`` method, so
    this thin adapter lets the single-token decoding step be exported without
    touching the decoder itself.
    """

    def __init__(self, decoder):
        """Store ``decoder``; it must provide a ``step`` method."""
        super().__init__()
        self.decoder = decoder

    def forward(self, decoder_input_ids, encoder_out, past_self_kv):
        """Forward all three arguments to ``decoder.step`` and return its result."""
        return self.decoder.step(decoder_input_ids, encoder_out, past_self_kv)
37
+
38
+
39
def export_encoder_to(model: torch.nn.Module, out_path: Path) -> None:
    """Trace ``model.encoder`` to an ONNX file at ``out_path``.

    The graph has one input, ``input_ids``, and one output, ``encoder_out``.
    Axes 0 and 1 of both are declared dynamic (``batch`` and ``seq``), so the
    1x16 dummy input used for tracing does not fix the runtime shape.

    Args:
        model: Model exposing the encoder as ``model.encoder``.
        out_path: Destination ``.onnx`` file; missing parent directories are
            created.
    """
    # The output path can be overridden to point outside the pre-created
    # artifacts directory, so make sure its parent exists before exporting.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    dummy_ids = torch.zeros(1, 16, dtype=torch.long)
    torch.onnx.export(
        model.encoder,
        (dummy_ids,),
        out_path,
        input_names=["input_ids"],
        output_names=["encoder_out"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "encoder_out": {0: "batch", 1: "seq"},
        },
        opset_version=17,
        do_constant_folding=True,
        external_data=False,
        dynamo=False,
    )

    sz = out_path.stat().st_size
    print(f"{out_path.name} written ({sz / 1e6:.1f} MB)")
51
+
52
+
53
def export_decoder_to(model, cfg: TransformerConfig, out_path: Path):
    """Trace one decoder step of ``model.decoder`` to an ONNX file.

    Graph inputs are ``decoder_input_ids``, ``encoder_out`` and
    ``past_self_kv``; graph outputs are ``logits`` and ``present_self_kv``.
    The batch axis of every tensor, the encoder sequence axis and the cache
    sequence axis are declared dynamic.

    Args:
        model: Model exposing the decoder as ``model.decoder``.
        cfg: Config supplying ``d_model``, ``num_heads``, ``num_kv_heads`` and
            ``num_decoder_layers``, which size the dummy inputs.
        out_path: Destination ``.onnx`` file; missing parent directories are
            created.
    """
    # The output path can be overridden to point outside the pre-created
    # artifacts directory, so make sure its parent exists before exporting.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    wrapper = DecoderStepWrapper(model.decoder)
    wrapper.eval()

    head_dim = cfg.d_model // cfg.num_heads
    batch, enc_seq, past_seq = 1, 16, 4

    # One token per step: decoder_input_ids is (batch, 1).
    dummy_dec_ids = torch.zeros(batch, 1, dtype=torch.long)
    dummy_enc_out = torch.zeros(batch, enc_seq, cfg.d_model, dtype=torch.float32)
    # Cache layout: (layers, 2, batch, kv_heads, past_seq, head_dim).
    # NOTE(review): the size-2 axis presumably holds key and value - confirm
    # against Decoder.step.
    dummy_past_kv = torch.zeros(
        cfg.num_decoder_layers, 2, batch, cfg.num_kv_heads, past_seq, head_dim,
        dtype=torch.float32,
    )

    torch.onnx.export(
        wrapper,
        (dummy_dec_ids, dummy_enc_out, dummy_past_kv),
        out_path,
        input_names=["decoder_input_ids", "encoder_out", "past_self_kv"],
        output_names=["logits", "present_self_kv"],
        dynamic_axes={
            "decoder_input_ids": {0: "batch"},
            "encoder_out": {0: "batch", 1: "enc_seq"},
            "past_self_kv": {2: "batch", 4: "past_seq"},
            "logits": {0: "batch"},
            "present_self_kv": {2: "batch", 4: "present_seq"},
        },
        opset_version=17,
        do_constant_folding=True,
        external_data=False,
        dynamo=False,
    )

    sz = out_path.stat().st_size
    print(f"{out_path.name} written ({sz / 1e6:.1f} MB)")
78
+
79
+
80
if __name__ == "__main__":
    # CLI entry point: load the converted checkpoint, then export both graphs.
    p = argparse.ArgumentParser()
    p.add_argument("--state", default=str(ART / "needle_torch.pt"),
                   help="PyTorch state_dict produced by convert_weights.py")
    p.add_argument("--config", default=str(ART / "needle_torch.config.json"),
                   help="Config JSON produced by convert_weights.py (same dim shape as the source ckpt)")
    p.add_argument("--encoder-out", default=str(ART / "encoder.onnx"),
                   help="Where to write the encoder ONNX graph")
    p.add_argument("--decoder-out", default=str(ART / "decoder_step.onnx"),
                   help="Where to write the decoder-step ONNX graph")
    args = p.parse_args()

    m, cfg = load_pt_model(Path(args.state), Path(args.config))
    export_encoder_to(m, Path(args.encoder_out))
    export_decoder_to(m, cfg, Path(args.decoder_out))