| """Trace the PyTorch port's encoder + decoder-step to ONNX. |
| |
| Config is read from `artifacts/needle_torch.config.json` (written by |
| `convert_weights.py`), so this script works as-is for any finetuned |
Cactus-architecture model. Override the input and output paths via CLI
flags if your artifacts live elsewhere.
| """ |
| import argparse |
| import json |
| from pathlib import Path |
| import torch |
|
|
| from needle_torch import NeedleModel, TransformerConfig |
|
|
# Directory holding the default inputs (state dict, config JSON) and outputs
# (ONNX graphs); it sits next to this script and is created on import.
ART = Path(__file__).resolve().with_name("artifacts")
ART.mkdir(exist_ok=True)
|
|
|
|
| def load_pt_model(state_path: Path, config_path: Path): |
| cfg_dict = json.loads(config_path.read_text()) |
| cfg = TransformerConfig(**cfg_dict) |
| m = NeedleModel(cfg) |
| m.train(False) |
| state = torch.load(state_path, map_location="cpu", weights_only=True) |
| m.load_state_dict(state, strict=True) |
| return m, cfg |
|
|
|
|
class DecoderStepWrapper(torch.nn.Module):
    """Expose ``decoder.step`` as ``forward`` so ``torch.onnx.export`` can trace it.

    The exporter traces a module's ``forward``. This adapter passes the three
    step inputs through positionally and returns whatever ``step`` returns,
    unchanged. Storing the decoder as an attribute registers it as a submodule
    when it is an ``nn.Module``, so its parameters are part of the export.
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

    def forward(self, decoder_input_ids, encoder_out, past_self_kv):
        step_fn = self.decoder.step
        step_inputs = (decoder_input_ids, encoder_out, past_self_kv)
        return step_fn(*step_inputs)
|
|
|
|
| def export_encoder_to(model, out_path: Path): |
| encoder = model.encoder |
| dummy_ids = torch.zeros(1, 16, dtype=torch.long) |
| torch.onnx.export( |
| encoder, (dummy_ids,), out_path, |
| input_names=["input_ids"], output_names=["encoder_out"], |
| dynamic_axes={"input_ids": {0: "batch", 1: "seq"}, |
| "encoder_out": {0: "batch", 1: "seq"}}, |
| opset_version=17, do_constant_folding=True, external_data=False, dynamo=False, |
| ) |
| sz = out_path.stat().st_size |
| print(f"{out_path.name} written ({sz / 1e6:.1f} MB)") |
|
|
|
|
| def export_decoder_to(model, cfg: TransformerConfig, out_path: Path): |
| wrapper = DecoderStepWrapper(model.decoder); wrapper.train(False) |
| head_dim = cfg.d_model // cfg.num_heads |
| batch, enc_seq, past_seq = 1, 16, 4 |
| dummy_dec_ids = torch.zeros(batch, 1, dtype=torch.long) |
| dummy_enc_out = torch.zeros(batch, enc_seq, cfg.d_model, dtype=torch.float32) |
| dummy_past_kv = torch.zeros( |
| cfg.num_decoder_layers, 2, batch, cfg.num_kv_heads, past_seq, head_dim, |
| dtype=torch.float32, |
| ) |
| torch.onnx.export( |
| wrapper, (dummy_dec_ids, dummy_enc_out, dummy_past_kv), out_path, |
| input_names=["decoder_input_ids", "encoder_out", "past_self_kv"], |
| output_names=["logits", "present_self_kv"], |
| dynamic_axes={ |
| "decoder_input_ids": {0: "batch"}, |
| "encoder_out": {0: "batch", 1: "enc_seq"}, |
| "past_self_kv": {2: "batch", 4: "past_seq"}, |
| "logits": {0: "batch"}, |
| "present_self_kv": {2: "batch", 4: "present_seq"}, |
| }, |
| opset_version=17, do_constant_folding=True, external_data=False, dynamo=False, |
| ) |
| sz = out_path.stat().st_size |
| print(f"{out_path.name} written ({sz / 1e6:.1f} MB)") |
|
|
|
|
| if __name__ == "__main__": |
| p = argparse.ArgumentParser() |
| p.add_argument("--state", default=str(ART / "needle_torch.pt"), |
| help="PyTorch state_dict produced by convert_weights.py") |
| p.add_argument("--config", default=str(ART / "needle_torch.config.json"), |
| help="Config JSON produced by convert_weights.py (same dim shape as the source ckpt)") |
| p.add_argument("--encoder-out", default=str(ART / "encoder.onnx")) |
| p.add_argument("--decoder-out", default=str(ART / "decoder_step.onnx")) |
| args = p.parse_args() |
|
|
| m, cfg = load_pt_model(Path(args.state), Path(args.config)) |
| export_encoder_to(m, Path(args.encoder_out)) |
| export_decoder_to(m, cfg, Path(args.decoder_out)) |
|
|