# ClauseGuard/ml/export_onnx_v2.py
# Author: gaurv007
# v4.3: Performance optimizations — ONNX INT8, BGE embedder, batched classification, thread control (#4)
# Commit: f4b6528
"""
ClauseGuard β€” ONNX Export + INT8 Quantization Pipeline (v2)
═══════════════════════════════════════════════════════════
PERF v4.3: Full pipeline to export the CUAD LoRA classifier to ONNX+INT8.
Steps:
1. Load base Legal-BERT + LoRA adapter
2. merge_and_unload() β†’ plain PreTrainedModel
3. Export to ONNX via optimum
4. Dynamic INT8 quantization (no calibration data needed)
5. Push quantized model to HuggingFace Hub
Usage:
pip install "optimum[onnxruntime]" peft transformers torch
python export_onnx_v2.py
# Or with custom paths:
HUB_MODEL_ID=gaurv007/clauseguard-onnx-int8 python export_onnx_v2.py
Hardware: Any CPU (no GPU needed for export)
Time: ~2-5 minutes
"""
import os
import sys
import shutil
# ── Configuration ──
BASE_MODEL = os.environ.get("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
ADAPTER_MODEL = os.environ.get("ADAPTER_MODEL", "Mokshith31/legalbert-contract-clause-classification")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-onnx-int8")
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
MERGED_DIR = "./merged_legalbert"
ONNX_DIR = "./onnx_legalbert"
QUANT_DIR = "./onnx_legalbert_int8"
CUAD_LABELS = [
"Document Name", "Parties", "Agreement Date", "Effective Date",
"Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
"Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
"No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
"Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
"Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
"Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
"Joint IP Ownership", "License Grant", "Non-Transferable License",
"Affiliate License-Licensor", "Affiliate License-Licensee",
"Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
"Source Code Escrow", "Post-Termination Services", "Audit Rights",
"Uncapped Liability", "Cap on Liability", "Liquidated Damages",
"Warranty Duration", "Insurance", "Covenant Not to Sue",
"Third Party Beneficiary", "Other",
]
def _merge_lora():
    """Steps 1-2: load base Legal-BERT plus the LoRA adapter, merge the
    adapter weights into the base model, and save the resulting plain
    model + tokenizer to ``MERGED_DIR``.

    Side effects: creates ``MERGED_DIR``; frees the torch models afterwards
    so the ONNX export does not hold two copies in memory.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from peft import PeftModel

    print("πŸ“¦ Step 1: Loading base model + LoRA adapter...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        # Derive the head size from the label list instead of hard-coding 41,
        # so the two can never drift apart.
        num_labels=len(CUAD_LABELS),
        ignore_mismatched_sizes=True,
    )
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)

    print("πŸ”€ Step 2: Merging LoRA weights into base model...")
    merged_model = peft_model.merge_and_unload(safe_merge=True)

    # Transformers' convention is int keys for id2label (they are stringified
    # automatically when the config is serialized to JSON); str(i) keys break
    # in-memory lookups like config.id2label[pred_id].
    merged_model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}
    merged_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}

    os.makedirs(MERGED_DIR, exist_ok=True)
    merged_model.save_pretrained(MERGED_DIR)
    tokenizer.save_pretrained(MERGED_DIR)
    print(f" βœ… Merged model saved to {MERGED_DIR}")

    # Release the large torch modules before the ONNX export allocates its own.
    del peft_model, base_model, merged_model
    import gc
    gc.collect()


def _export_and_quantize():
    """Steps 3-4: export the merged model to ONNX via optimum, then apply
    dynamic INT8 quantization (weights quantized offline, activations at
    runtime — no calibration data needed).

    Side effects: creates ``ONNX_DIR`` and ``QUANT_DIR``, each self-contained
    with tokenizer files and config.json.
    """
    from transformers import AutoTokenizer
    from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    print("\nπŸ“€ Step 3: Exporting to ONNX...")
    ort_model = ORTModelForSequenceClassification.from_pretrained(
        MERGED_DIR, export=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MERGED_DIR)
    os.makedirs(ONNX_DIR, exist_ok=True)
    ort_model.save_pretrained(ONNX_DIR)
    tokenizer.save_pretrained(ONNX_DIR)
    print(f" βœ… ONNX model saved to {ONNX_DIR}")

    print("\n⚑ Step 4: Applying dynamic INT8 quantization...")
    qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
    quantizer = ORTQuantizer.from_pretrained(ort_model)
    os.makedirs(QUANT_DIR, exist_ok=True)
    quantizer.quantize(save_dir=QUANT_DIR, quantization_config=qconfig)
    # The quantized dir must be loadable on its own: add tokenizer + config.
    tokenizer.save_pretrained(QUANT_DIR)
    shutil.copy2(os.path.join(ONNX_DIR, "config.json"), QUANT_DIR)
    print(f" βœ… Quantized model saved to {QUANT_DIR}")


def _verify():
    """Step 5: smoke-test the quantized model on a few clause snippets.

    Returns:
        (quant_model, quant_tokenizer) so the caller can push them to the Hub.
    """
    import torch
    from transformers import AutoTokenizer
    from optimum.onnxruntime import ORTModelForSequenceClassification

    print("\nπŸ§ͺ Step 5: Verifying quantized model...")
    quant_model = ORTModelForSequenceClassification.from_pretrained(
        QUANT_DIR, file_name="model_quantized.onnx"
    )
    quant_tokenizer = AutoTokenizer.from_pretrained(QUANT_DIR)
    test_texts = [
        "The company may terminate your account at any time without notice.",
        "Either party shall indemnify and hold harmless the other party.",
        "This Agreement shall be governed by the laws of the State of Delaware.",
    ]
    inputs = quant_tokenizer(
        test_texts, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        outputs = quant_model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
    for i, text in enumerate(test_texts):
        top_prob, top_idx = torch.max(probs[i], dim=0)
        idx = int(top_idx)
        # Guard against a head wider than the label list (possible because the
        # base model was loaded with ignore_mismatched_sizes=True).
        label = CUAD_LABELS[idx] if idx < len(CUAD_LABELS) else f"Class-{idx}"
        print(f" Text: {text[:60]}...")
        print(f" β†’ {label} ({top_prob:.3f})")
    return quant_model, quant_tokenizer


def main():
    """Run the full pipeline: merge LoRA β†’ ONNX export β†’ INT8 quantization
    β†’ verification β†’ optional Hub push β†’ size summary."""
    print("πŸ›‘οΈ ClauseGuard ONNX Export + INT8 Quantization")
    print("=" * 60)
    print(f" Base model: {BASE_MODEL}")
    print(f" LoRA adapter: {ADAPTER_MODEL}")
    print(f" Hub target: {HUB_MODEL_ID}")
    print()

    _merge_lora()
    _export_and_quantize()
    quant_model, quant_tokenizer = _verify()

    # ── Step 6: Push to Hub ──
    if PUSH_TO_HUB:
        print(f"\nπŸš€ Step 6: Pushing to {HUB_MODEL_ID}...")
        # NOTE(review): `use_auth_token` is deprecated in recent transformers /
        # huggingface_hub in favor of `token=` — confirm against pinned versions.
        quant_model.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        quant_tokenizer.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        print(f" βœ… Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    else:
        print("\n⏭️ Skipping Hub push (PUSH_TO_HUB=false)")

    # ── Summary ── (file names are the optimum defaults for export/quantize)
    onnx_size = os.path.getsize(os.path.join(ONNX_DIR, "model.onnx")) / 1e6
    quant_size = os.path.getsize(os.path.join(QUANT_DIR, "model_quantized.onnx")) / 1e6
    print(f"\n{'='*60}")
    print(f" πŸ“Š ONNX model size: {onnx_size:.1f} MB")
    print(f" πŸ“Š Quantized model size: {quant_size:.1f} MB")
    print(f" πŸ“Š Size reduction: {(1 - quant_size/onnx_size)*100:.0f}%")
    print(f" πŸ”₯ Expected speedup: 2-4x on CPU")
    print(f"{'='*60}")
    print("\nβœ… Export complete!")
    print(f"\nTo use in ClauseGuard, set ONNX_MODEL_PATH={QUANT_DIR}")
    print("or point to the Hub model: gaurv007/clauseguard-onnx-int8")


if __name__ == "__main__":
    main()