File size: 3,731 Bytes
76fda9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Trace the PyTorch port's encoder + decoder-step to ONNX.

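The model config is read from `artifacts/needle_torch.config.json` (written by
`convert_weights.py`), so this script works as-is for any finetuned
Cactus-architecture model. The exported graphs go to `artifacts/encoder.onnx`
and `artifacts/decoder_step.onnx` by default. Override any input or output path
via the CLI flags (`--state`, `--config`, `--encoder-out`, `--decoder-out`) if
your artifacts live elsewhere.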
"""
import argparse
import json
from pathlib import Path
import torch

from needle_torch import NeedleModel, TransformerConfig

# Default home for every input/output artifact: an `artifacts/` directory
# beside this script. It is created at import time; no `parents=True` is
# needed because its parent is this script's own directory.
ART = Path(__file__).resolve().parent.joinpath("artifacts")
ART.mkdir(exist_ok=True)


def load_pt_model(state_path: Path, config_path: Path) -> "tuple[NeedleModel, TransformerConfig]":
    """Build a NeedleModel from its JSON config and load a saved state_dict.

    Args:
        state_path: state_dict file, loaded with ``torch.load`` onto the CPU
            with ``weights_only=True``. A ``str`` path is also accepted.
        config_path: JSON file whose top-level keys are passed as keyword
            arguments to ``TransformerConfig``. A ``str`` path is also accepted.

    Returns:
        ``(model, cfg)``: the model switched out of training mode
        (``train(False)``) with the checkpoint weights loaded, and the config
        it was built from.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match the model (``strict=True``).
    """
    state_path = Path(state_path)
    config_path = Path(config_path)

    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding, which plain read_text() would use.
    cfg_dict = json.loads(config_path.read_text(encoding="utf-8"))
    cfg = TransformerConfig(**cfg_dict)

    model = NeedleModel(cfg)
    # Equivalent to model.eval(): layers such as dropout (if any) switch to
    # inference behavior before the model is traced for export.
    model.train(False)

    # weights_only=True restricts unpickling to tensors/primitive containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    state = torch.load(state_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state, strict=True)
    return model, cfg


class DecoderStepWrapper(torch.nn.Module):
    """Expose ``decoder.step`` as this module's ``forward``.

    ``torch.onnx.export`` traces a module's ``forward``; routing it to the
    decoder's single-step method lets that step be exported as its own graph.
    """

    def __init__(self, decoder):
        super().__init__()
        # nn.Module registers a Module-valued attribute as a child, so the
        # decoder's parameters and train/eval mode follow the wrapper.
        self.decoder = decoder

    def forward(self, decoder_input_ids, encoder_out, past_self_kv):
        """Run one decoder step, forwarding all three inputs positionally."""
        step_fn = self.decoder.step
        return step_fn(decoder_input_ids, encoder_out, past_self_kv)


def export_encoder_to(model, out_path: Path):
    """Export ``model.encoder`` to a standalone ONNX graph at *out_path*.

    The encoder is traced with the TorchScript exporter (``dynamo=False``)
    using an all-zeros ``(1, 16)`` int64 ``input_ids`` dummy tensor. Axes 0
    and 1 of ``input_ids`` and ``encoder_out`` are declared dynamic
    (``batch``, ``seq``), so the graph is not tied to those dummy sizes.
    Prints the written file's name and size.

    Args:
        model: object whose ``.encoder`` is the module to export. Assumed to
            already be in inference mode (``load_pt_model`` does this).
        out_path: destination ``.onnx`` file. A ``str`` is also accepted;
            missing parent directories are created.
    """
    out_path = Path(out_path)
    # torch.onnx.export does not create directories, so a custom
    # --encoder-out inside a new folder would fail with FileNotFoundError.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    dummy_ids = torch.zeros(1, 16, dtype=torch.long)
    torch.onnx.export(
        model.encoder,
        (dummy_ids,),
        out_path,
        input_names=["input_ids"],
        output_names=["encoder_out"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "encoder_out": {0: "batch", 1: "seq"},
        },
        opset_version=17,
        do_constant_folding=True,
        # Keep the weights inside the single .onnx file, with no external
        # side files. NOTE: protobuf caps one file at 2 GB.
        external_data=False,
        dynamo=False,
    )
    size_mb = out_path.stat().st_size / 1e6
    print(f"{out_path.name} written ({size_mb:.1f} MB)")


def export_decoder_to(model, cfg: TransformerConfig, out_path: Path):
    """Export one decoder step (``model.decoder.step``) to ONNX at *out_path*.

    The step is traced through ``DecoderStepWrapper`` with all-zeros dummy
    inputs:

    * ``decoder_input_ids``: ``(batch, 1)`` int64
    * ``encoder_out``: ``(batch, enc_seq, d_model)`` float32
    * ``past_self_kv``: ``(num_decoder_layers, 2, batch, num_kv_heads,
      past_seq, head_dim)`` float32

    Batch, encoder length and cache length are declared dynamic. The
    returned cache ``present_self_kv`` is declared with the same layout and a
    dynamic ``present_seq`` axis. Prints the written file's name and size.

    Args:
        model: object whose ``.decoder`` has a
            ``step(decoder_input_ids, encoder_out, past_self_kv)`` method.
        cfg: supplies ``d_model``, ``num_heads``, ``num_kv_heads`` and
            ``num_decoder_layers`` for the dummy tensor shapes.
        out_path: destination ``.onnx`` file. A ``str`` is also accepted;
            missing parent directories are created.
    """
    out_path = Path(out_path)
    # torch.onnx.export does not create directories, so a custom
    # --decoder-out inside a new folder would fail with FileNotFoundError.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    wrapper = DecoderStepWrapper(model.decoder)
    wrapper.train(False)

    # NOTE(review): assumes head_dim == d_model // num_heads, i.e. the config
    # has no independent head size. TODO confirm against TransformerConfig.
    head_dim = cfg.d_model // cfg.num_heads
    # Distinct dummy sizes for the sequence axes keep the tracer from
    # conflating them.
    batch, enc_seq, past_seq = 1, 16, 4

    dummy_dec_ids = torch.zeros(batch, 1, dtype=torch.long)
    dummy_enc_out = torch.zeros(batch, enc_seq, cfg.d_model, dtype=torch.float32)
    # The size-2 axis presumably stacks keys and values. TODO confirm in
    # Decoder.step.
    dummy_past_kv = torch.zeros(
        cfg.num_decoder_layers, 2, batch, cfg.num_kv_heads, past_seq, head_dim,
        dtype=torch.float32,
    )

    torch.onnx.export(
        wrapper,
        (dummy_dec_ids, dummy_enc_out, dummy_past_kv),
        out_path,
        input_names=["decoder_input_ids", "encoder_out", "past_self_kv"],
        output_names=["logits", "present_self_kv"],
        dynamic_axes={
            "decoder_input_ids": {0: "batch"},
            "encoder_out":       {0: "batch", 1: "enc_seq"},
            "past_self_kv":      {2: "batch", 4: "past_seq"},
            "logits":            {0: "batch"},
            "present_self_kv":   {2: "batch", 4: "present_seq"},
        },
        opset_version=17,
        do_constant_folding=True,
        # Keep the weights inside the single .onnx file, with no external
        # side files. NOTE: protobuf caps one file at 2 GB.
        external_data=False,
        dynamo=False,
    )
    size_mb = out_path.stat().st_size / 1e6
    print(f"{out_path.name} written ({size_mb:.1f} MB)")


if __name__ == "__main__":
    # CLI entry point: load the converted PyTorch model, then write the
    # encoder graph and the decoder-step graph. Every path defaults to ART.
    # type=Path also converts the string defaults, so every option arrives
    # as a Path.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--state",
        type=Path,
        default=str(ART / "needle_torch.pt"),
        help="PyTorch state_dict produced by convert_weights.py",
    )
    parser.add_argument(
        "--config",
        type=Path,
        default=str(ART / "needle_torch.config.json"),
        help="Config JSON produced by convert_weights.py (same dim shape as the source ckpt)",
    )
    parser.add_argument("--encoder-out", type=Path, default=str(ART / "encoder.onnx"))
    parser.add_argument("--decoder-out", type=Path, default=str(ART / "decoder_step.onnx"))
    opts = parser.parse_args()

    model, config = load_pt_model(opts.state, opts.config)
    export_encoder_to(model, opts.encoder_out)
    export_decoder_to(model, config, opts.decoder_out)