# /// script
# requires-python = ">=3.12"
# dependencies = [
# "accelerate>=1.13.0",
# "flash-linear-attention>=0.4.2",
# "hf-xet>=1.4.3",
# "huggingface-hub>=1.8.0",
# "onnx>=1.21.0",
# "onnx-ir>=0.2.0",
# "onnxruntime>=1.24.4",
# "onnxruntime-genai>=0.13.1",
# "optimum>=2.1.0",
# "sentencepiece>=0.2.1",
# "tiktoken>=0.12.0",
# "torch>=2.11.0",
# "transformers==5.7.0",
# ]
# ///
import argparse
from pathlib import Path
from huggingface_hub import snapshot_download
#from onnxruntime_genai.python.models.builder import create_model
from onnxruntime_genai.models.builder import create_model
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the download + ONNX conversion pipeline."""
    parser = argparse.ArgumentParser(
        description="Download a Hugging Face model and convert it to ONNX Runtime GenAI format.",
    )
    parser.add_argument(
        "--name",
        required=False,
        default=None,
        help="Hugging Face repo id. Required unless ./model already contains a downloaded model.",
    )
    parser.add_argument(
        "--token",
        required=False,
        help="Hugging Face access token (needed for gated or private repos).",
    )
    parser.add_argument(
        "--precision",
        default="fp16",
        help="Builder precision, e.g. fp32 | fp16 | int4 (default: fp16; see the builder's --help).",
    )
    parser.add_argument(
        "--execution-provider",
        default="cpu",
        help="Target execution provider, e.g. cpu | cuda | dml (default: cpu).",
    )
    return parser


def _has_local_model(model_dir: Path) -> bool:
    """Return True if *model_dir* already holds a Hugging Face checkout (has ``config.json``)."""
    return (model_dir / "config.json").is_file()


def main() -> None:
    """Download a Hugging Face model (if needed) and convert it with the ONNX GenAI builder.

    Three working directories are created under the current directory:
    ``model/`` (Hugging Face snapshot), ``onnx/`` (builder output) and
    ``cache/`` (builder cache).

    The download is skipped when ``model/config.json`` already exists, so a
    previously downloaded or manually placed model is reused. If there is no
    local model and ``--name`` is not given, the script exits with a usage error.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # Treat an empty --token "" the same as no token.
    token = args.token or None

    pwd = Path.cwd()
    model_dir = pwd / "model"
    onnx_dir = pwd / "onnx"
    cache_dir = pwd / "cache"
    for directory in (model_dir, onnx_dir, cache_dir):
        directory.mkdir(exist_ok=True)

    # ===== STEP 1: DOWNLOAD (HF Hub; Xet backend is used when hf-xet is installed) =====
    # The builder loads from input_path whenever that directory exists, so an
    # empty model/ dir would make the conversion fail; populate it first.
    if _has_local_model(model_dir):
        print(f">> Reusing existing model in {model_dir}")
    elif args.name:
        print(">> Downloading model via huggingface_hub (Xet enabled if installed)...")
        local_path = snapshot_download(
            repo_id=args.name,
            local_dir=str(model_dir),
            token=token,
        )
        print(f"Model downloaded to: {local_path}")
    else:
        # parser.error prints usage and exits with status 2.
        parser.error(f"--name is required: no model found in {model_dir}")

    # ===== STEP 2: CONVERT USING ONNX GENAI BUILDER =====
    print(">> Converting to ONNX (GenAI format)...")
    # create_model takes builder options as **kwargs; none are needed here, so
    # nothing extra is passed (passing extra_options={} would create a bogus key).
    create_model(
        model_name=args.name,
        input_path=str(model_dir),  # HF model directory
        output_dir=str(onnx_dir),  # ONNX output
        precision=args.precision,
        execution_provider=args.execution_provider,
        cache_dir=str(cache_dir),  # builder cache
    )
    print("\n✅ Done")
    print(f"ONNX model at: {onnx_dir}")
if __name__ == "__main__":
    # Run the download + conversion pipeline only when this file is executed
    # directly, not when it is imported as a module.
    main()