"""
ClauseGuard — ONNX Export + INT8 Quantization Pipeline (v2)
-----------------------------------------------------------
PERF v4.3: Full pipeline to export the CUAD LoRA classifier to ONNX+INT8.

Steps:
1. Load base Legal-BERT + LoRA adapter
2. merge_and_unload() -> plain PreTrainedModel
3. Export to ONNX via optimum
4. Dynamic INT8 quantization (no calibration data needed)
5. Push quantized model to HuggingFace Hub

Usage:
    pip install "optimum[onnxruntime]" peft transformers torch
    python export_onnx_v2.py
    # Or with custom paths:
    HUB_MODEL_ID=gaurv007/clauseguard-onnx-int8 python export_onnx_v2.py

Hardware: Any CPU (no GPU needed for export)
Time: ~2-5 minutes
"""
import os
import sys
import shutil
# ── Configuration ──
# All settings are env-overridable; defaults target the public ClauseGuard repos.
# Base encoder checkpoint the LoRA adapter was trained on.
BASE_MODEL = os.environ.get("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
# LoRA adapter repo to merge into the base model.
ADAPTER_MODEL = os.environ.get("ADAPTER_MODEL", "Mokshith31/legalbert-contract-clause-classification")
# Destination Hub repo for the quantized ONNX model.
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-onnx-int8")
# Only the exact string "true" (case-insensitive) enables the Hub push.
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
# Local working directories for each pipeline stage.
MERGED_DIR = "./merged_legalbert"
ONNX_DIR = "./onnx_legalbert"
QUANT_DIR = "./onnx_legalbert_int8"
# The 41 CUAD (Contract Understanding Atticus Dataset) clause categories,
# in label-id order: index i is class id i. The length of this list must
# match the classifier head size configured in main() (num_labels=41).
CUAD_LABELS = [
    "Document Name", "Parties", "Agreement Date", "Effective Date",
    "Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
    "Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
    "No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
    "Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
    "Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
    "Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
    "Joint IP Ownership", "License Grant", "Non-Transferable License",
    "Affiliate License-Licensor", "Affiliate License-Licensee",
    "Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
    "Source Code Escrow", "Post-Termination Services", "Audit Rights",
    "Uncapped Liability", "Cap on Liability", "Liquidated Damages",
    "Warranty Duration", "Insurance", "Covenant Not to Sue",
    "Third Party Beneficiary", "Other",
]
def main():
    """Export the ClauseGuard CUAD classifier to ONNX and quantize it to INT8.

    Pipeline: merge the LoRA adapter into Legal-BERT, export the merged model
    to ONNX via optimum, apply dynamic INT8 quantization (no calibration data),
    sanity-check the quantized model on a few sample clauses, and optionally
    push the result to the HuggingFace Hub.

    Raises whatever the underlying transformers/peft/optimum calls raise
    (network errors, missing checkpoints, auth failures on push).
    """
    print("ClauseGuard ONNX Export + INT8 Quantization")
    print("=" * 60)
    print(f"  Base model:   {BASE_MODEL}")
    print(f"  LoRA adapter: {ADAPTER_MODEL}")
    print(f"  Hub target:   {HUB_MODEL_ID}")
    print()

    # ── Step 1: Load base model + LoRA adapter ──
    print("Step 1: Loading base model + LoRA adapter...")
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Derive the head size from the label list instead of hard-coding 41,
    # so the two cannot drift apart.
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=len(CUAD_LABELS), ignore_mismatched_sizes=True
    )
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)

    # ── Step 2: Merge LoRA weights into the base model ──
    print("Step 2: Merging LoRA weights into base model...")
    merged_model = peft_model.merge_and_unload(safe_merge=True)

    # transformers convention: id2label keys are ints in memory (they are
    # serialized to strings in config.json automatically).
    merged_model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}
    merged_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}

    os.makedirs(MERGED_DIR, exist_ok=True)
    merged_model.save_pretrained(MERGED_DIR)
    tokenizer.save_pretrained(MERGED_DIR)
    print(f"  [OK] Merged model saved to {MERGED_DIR}")

    # Free the PyTorch copies before the ONNX export loads its own.
    del peft_model, base_model, merged_model
    import gc
    gc.collect()

    # ── Step 3: Export to ONNX ──
    print("\nStep 3: Exporting to ONNX...")
    from optimum.onnxruntime import ORTModelForSequenceClassification

    ort_model = ORTModelForSequenceClassification.from_pretrained(
        MERGED_DIR, export=True
    )
    os.makedirs(ONNX_DIR, exist_ok=True)
    ort_model.save_pretrained(ONNX_DIR)
    tokenizer.save_pretrained(ONNX_DIR)
    print(f"  [OK] ONNX model saved to {ONNX_DIR}")

    # ── Step 4: Dynamic INT8 quantization ──
    print("\nStep 4: Applying dynamic INT8 quantization...")
    from optimum.onnxruntime import ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    # Dynamic (is_static=False) quantization needs no calibration dataset.
    # NOTE(review): avx512_vnni assumes the serving CPU supports VNNI —
    # confirm, or switch to AutoQuantizationConfig.avx2/arm64 as appropriate.
    qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
    quantizer = ORTQuantizer.from_pretrained(ort_model)
    os.makedirs(QUANT_DIR, exist_ok=True)
    quantizer.quantize(save_dir=QUANT_DIR, quantization_config=qconfig)
    # Ship the tokenizer and model config alongside the quantized weights so
    # the directory is loadable on its own.
    tokenizer.save_pretrained(QUANT_DIR)
    shutil.copy2(os.path.join(ONNX_DIR, "config.json"), QUANT_DIR)
    print(f"  [OK] Quantized model saved to {QUANT_DIR}")

    # ── Step 5: Verify the quantized model end-to-end ──
    print("\nStep 5: Verifying quantized model...")
    quant_model = ORTModelForSequenceClassification.from_pretrained(
        QUANT_DIR, file_name="model_quantized.onnx"
    )
    quant_tokenizer = AutoTokenizer.from_pretrained(QUANT_DIR)
    test_texts = [
        "The company may terminate your account at any time without notice.",
        "Either party shall indemnify and hold harmless the other party.",
        "This Agreement shall be governed by the laws of the State of Delaware.",
    ]
    inputs = quant_tokenizer(
        test_texts, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    import torch
    with torch.no_grad():
        outputs = quant_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)
    for i, text in enumerate(test_texts):
        top_prob, top_idx = torch.max(probs[i], dim=0)
        idx = int(top_idx)
        # Guard against a head larger than the label list (ignore_mismatched_sizes).
        label = CUAD_LABELS[idx] if idx < len(CUAD_LABELS) else f"Class-{idx}"
        print(f"  Text: {text[:60]}...")
        print(f"  -> {label} ({top_prob:.3f})")

    # ── Step 6: Push to Hub (optional) ──
    if PUSH_TO_HUB:
        print(f"\nStep 6: Pushing to {HUB_MODEL_ID}...")
        # NOTE(review): use_auth_token is deprecated in newer huggingface_hub
        # releases in favor of `token=` — confirm against the pinned version.
        quant_model.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        quant_tokenizer.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        print(f"  [OK] Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    else:
        print("\nSkipping Hub push (PUSH_TO_HUB=false)")

    # ── Summary ──
    onnx_size = os.path.getsize(os.path.join(ONNX_DIR, "model.onnx")) / 1e6
    quant_size = os.path.getsize(os.path.join(QUANT_DIR, "model_quantized.onnx")) / 1e6
    print(f"\n{'=' * 60}")
    print(f"  ONNX model size:      {onnx_size:.1f} MB")
    print(f"  Quantized model size: {quant_size:.1f} MB")
    print(f"  Size reduction:       {(1 - quant_size / onnx_size) * 100:.0f}%")
    print("  Expected speedup: 2-4x on CPU")
    print(f"{'=' * 60}")
    print("\nExport complete!")
    print(f"\nTo use in ClauseGuard, set ONNX_MODEL_PATH={QUANT_DIR}")
    print("or point to the Hub model: gaurv007/clauseguard-onnx-int8")


if __name__ == "__main__":
    main()