Spaces:
Running
Running
v4.3 perf: Update ml/export_onnx_v2.py
Browse files- ml/export_onnx_v2.py +169 -0
ml/export_onnx_v2.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
ClauseGuard — ONNX Export + INT8 Quantization Pipeline (v2)
-----------------------------------------------------------
PERF v4.3: Full pipeline to export the CUAD LoRA classifier to ONNX + INT8.

Steps:
  1. Load base Legal-BERT + LoRA adapter
  2. merge_and_unload() -> plain PreTrainedModel
  3. Export to ONNX via optimum
  4. Dynamic INT8 quantization (no calibration data needed)
  5. Push quantized model to HuggingFace Hub

Usage:
    pip install "optimum[onnxruntime]" peft transformers torch
    python export_onnx_v2.py

    # Or with custom paths:
    HUB_MODEL_ID=gaurv007/clauseguard-onnx-int8 python export_onnx_v2.py

Hardware: Any CPU (no GPU needed for export)
Time: ~2-5 minutes
"""

import os
import shutil
import sys  # NOTE(review): unused below — kept to avoid breaking any external tooling
# ── Configuration ──
# Every setting can be overridden via environment variables for CI / scripted runs.
BASE_MODEL = os.environ.get("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
ADAPTER_MODEL = os.environ.get("ADAPTER_MODEL", "Mokshith31/legalbert-contract-clause-classification")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-onnx-int8")
# Push is on by default; set PUSH_TO_HUB=false (any casing) to skip the upload step.
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"

# Local working directories, one per pipeline stage (merged -> ONNX -> INT8).
MERGED_DIR = "./merged_legalbert"
ONNX_DIR = "./onnx_legalbert"
QUANT_DIR = "./onnx_legalbert_int8"
# The 41 CUAD (Contract Understanding Atticus Dataset) clause categories, in
# the index order used by the classifier head (position i == class id i).
CUAD_LABELS = [
    "Document Name", "Parties", "Agreement Date", "Effective Date",
    "Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
    "Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
    "No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
    "Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
    "Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
    "Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
    "Joint IP Ownership", "License Grant", "Non-Transferable License",
    "Affiliate License-Licensor", "Affiliate License-Licensee",
    "Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
    "Source Code Escrow", "Post-Termination Services", "Audit Rights",
    "Uncapped Liability", "Cap on Liability", "Liquidated Damages",
    "Warranty Duration", "Insurance", "Covenant Not to Sue",
    "Third Party Beneficiary", "Other",
]
def main():
    """Run the full export pipeline.

    Merges the LoRA adapter into the Legal-BERT base model, exports the merged
    model to ONNX, applies dynamic INT8 quantization, verifies the quantized
    model on a few sample clauses, and optionally pushes the result to the
    HuggingFace Hub.

    Configuration comes from the module-level constants (all overridable via
    environment variables). Artifacts are written to MERGED_DIR, ONNX_DIR and
    QUANT_DIR. Raises whatever the underlying transformers/peft/optimum calls
    raise on failure (network, disk, auth).
    """
    print("🛡️ ClauseGuard ONNX Export + INT8 Quantization")
    print("=" * 60)
    print(f"  Base model:   {BASE_MODEL}")
    print(f"  LoRA adapter: {ADAPTER_MODEL}")
    print(f"  Hub target:   {HUB_MODEL_ID}")
    print()

    # ── Step 1: Load and merge LoRA ──
    # Heavy ML imports are deferred into the function so the module imports fast.
    print("📦 Step 1: Loading base model + LoRA adapter...")
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        # len(CUAD_LABELS) (== 41) rather than a hard-coded head size keeps the
        # classifier head in sync with the label list defined above.
        BASE_MODEL, num_labels=len(CUAD_LABELS), ignore_mismatched_sizes=True
    )
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)

    print("🔀 Step 2: Merging LoRA weights into base model...")
    # safe_merge=True checks the merged weights for NaN/inf before committing.
    merged_model = peft_model.merge_and_unload(safe_merge=True)

    # Set the label mapping. transformers' convention is {int: str} for
    # id2label; keys are stringified automatically when config.json is saved.
    merged_model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}
    merged_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}

    os.makedirs(MERGED_DIR, exist_ok=True)
    merged_model.save_pretrained(MERGED_DIR)
    tokenizer.save_pretrained(MERGED_DIR)
    print(f"  ✅ Merged model saved to {MERGED_DIR}")

    # Free memory: the ONNX export below reloads the weights from disk.
    del peft_model, base_model, merged_model
    import gc
    gc.collect()

    # ── Step 3: Export to ONNX ──
    print("\n📤 Step 3: Exporting to ONNX...")
    from optimum.onnxruntime import ORTModelForSequenceClassification

    ort_model = ORTModelForSequenceClassification.from_pretrained(
        MERGED_DIR, export=True
    )
    os.makedirs(ONNX_DIR, exist_ok=True)
    ort_model.save_pretrained(ONNX_DIR)
    tokenizer.save_pretrained(ONNX_DIR)
    print(f"  ✅ ONNX model saved to {ONNX_DIR}")

    # ── Step 4: Dynamic INT8 Quantization ──
    print("\n⚡ Step 4: Applying dynamic INT8 quantization...")
    from optimum.onnxruntime import ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    # Dynamic quantization (is_static=False) needs no calibration dataset.
    # NOTE(review): avx512_vnni targets VNNI-capable x86 CPUs at *inference*
    # time — confirm the deployment hardware, or use avx2/arm64 configs instead.
    qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
    quantizer = ORTQuantizer.from_pretrained(ort_model)
    os.makedirs(QUANT_DIR, exist_ok=True)
    quantizer.quantize(save_dir=QUANT_DIR, quantization_config=qconfig)

    # Copy tokenizer files to quantized dir.
    tokenizer.save_pretrained(QUANT_DIR)
    # Copy config.json too — tokenizer.save_pretrained does not write the model config.
    shutil.copy2(os.path.join(ONNX_DIR, "config.json"), QUANT_DIR)
    print(f"  ✅ Quantized model saved to {QUANT_DIR}")

    # ── Step 5: Verify ──
    print("\n🧪 Step 5: Verifying quantized model...")
    quant_model = ORTModelForSequenceClassification.from_pretrained(
        QUANT_DIR, file_name="model_quantized.onnx"
    )
    quant_tokenizer = AutoTokenizer.from_pretrained(QUANT_DIR)

    test_texts = [
        "The company may terminate your account at any time without notice.",
        "Either party shall indemnify and hold harmless the other party.",
        "This Agreement shall be governed by the laws of the State of Delaware.",
    ]
    inputs = quant_tokenizer(test_texts, return_tensors="pt", padding=True, truncation=True, max_length=512)

    import torch
    with torch.no_grad():
        outputs = quant_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)

    for i, text in enumerate(test_texts):
        top_prob, top_idx = torch.max(probs[i], dim=0)
        # Guard against a head larger than the label list (ignore_mismatched_sizes).
        label = CUAD_LABELS[int(top_idx)] if int(top_idx) < len(CUAD_LABELS) else f"Class-{int(top_idx)}"
        print(f"  Text: {text[:60]}...")
        print(f"  → {label} ({top_prob:.3f})")

    # ── Step 6: Push to Hub ──
    if PUSH_TO_HUB:
        print(f"\n🚀 Step 6: Pushing to {HUB_MODEL_ID}...")
        # NOTE(review): use_auth_token is deprecated in newer huggingface_hub
        # releases in favour of `token=` — confirm against the pinned version.
        quant_model.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        quant_tokenizer.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        print(f"  ✅ Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    else:
        print("\n⏭️ Skipping Hub push (PUSH_TO_HUB=false)")

    # ── Summary ──
    onnx_size = os.path.getsize(os.path.join(ONNX_DIR, "model.onnx")) / 1e6
    quant_size = os.path.getsize(os.path.join(QUANT_DIR, "model_quantized.onnx")) / 1e6
    print(f"\n{'='*60}")
    print(f"  📊 ONNX model size:      {onnx_size:.1f} MB")
    print(f"  📊 Quantized model size: {quant_size:.1f} MB")
    print(f"  📉 Size reduction:       {(1 - quant_size/onnx_size)*100:.0f}%")
    print(f"  🔥 Expected speedup:     2-4x on CPU")
    print(f"{'='*60}")
    print("\n✅ Export complete!")
    print(f"\nTo use in ClauseGuard, set ONNX_MODEL_PATH={QUANT_DIR}")
    print("or point to the Hub model: gaurv007/clauseguard-onnx-int8")
# Script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()