Spaces:
Sleeping
Sleeping
| """ | |
| ClauseGuard β ONNX Export + INT8 Quantization Pipeline (v2) | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| PERF v4.3: Full pipeline to export the CUAD LoRA classifier to ONNX+INT8. | |
| Steps: | |
| 1. Load base Legal-BERT + LoRA adapter | |
| 2. merge_and_unload() β plain PreTrainedModel | |
| 3. Export to ONNX via optimum | |
| 4. Dynamic INT8 quantization (no calibration data needed) | |
| 5. Push quantized model to HuggingFace Hub | |
| Usage: | |
| pip install "optimum[onnxruntime]" peft transformers torch | |
| python export_onnx_v2.py | |
| # Or with custom paths: | |
| HUB_MODEL_ID=gaurv007/clauseguard-onnx-int8 python export_onnx_v2.py | |
| Hardware: Any CPU (no GPU needed for export) | |
| Time: ~2-5 minutes | |
| """ | |
import os
import shutil

# ── Configuration (each value overridable via environment variable) ──
BASE_MODEL = os.getenv("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
ADAPTER_MODEL = os.getenv(
    "ADAPTER_MODEL", "Mokshith31/legalbert-contract-clause-classification"
)
HUB_MODEL_ID = os.getenv("HUB_MODEL_ID", "gaurv007/clauseguard-onnx-int8")
PUSH_TO_HUB = os.getenv("PUSH_TO_HUB", "true").lower() == "true"

# Local working directories, one per pipeline stage.
MERGED_DIR = "./merged_legalbert"
ONNX_DIR = "./onnx_legalbert"
QUANT_DIR = "./onnx_legalbert_int8"

# The 41 CUAD clause categories; list index == classifier-head label id.
CUAD_LABELS = [
    "Document Name",
    "Parties",
    "Agreement Date",
    "Effective Date",
    "Expiration Date",
    "Renewal Term",
    "Notice Period to Terminate Renewal",
    "Governing Law",
    "Most Favored Nation",
    "Non-Compete",
    "Exclusivity",
    "No-Solicit of Customers",
    "No-Solicit of Employees",
    "Non-Disparagement",
    "Termination for Convenience",
    "ROFR/ROFO/ROFN",
    "Change of Control",
    "Anti-Assignment",
    "Revenue/Profit Sharing",
    "Price Restriction",
    "Minimum Commitment",
    "Volume Restriction",
    "IP Ownership Assignment",
    "Joint IP Ownership",
    "License Grant",
    "Non-Transferable License",
    "Affiliate License-Licensor",
    "Affiliate License-Licensee",
    "Unlimited/All-You-Can-Eat License",
    "Irrevocable or Perpetual License",
    "Source Code Escrow",
    "Post-Termination Services",
    "Audit Rights",
    "Uncapped Liability",
    "Cap on Liability",
    "Liquidated Damages",
    "Warranty Duration",
    "Insurance",
    "Covenant Not to Sue",
    "Third Party Beneficiary",
    "Other",
]
def _merge_lora():
    """Steps 1-2: load base Legal-BERT + LoRA adapter, merge, save to MERGED_DIR.

    Returns the tokenizer so later steps can reuse it without reloading.
    """
    print("📦 Step 1: Loading base model + LoRA adapter...")
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        # Derive the head size from CUAD_LABELS so the two cannot drift apart.
        num_labels=len(CUAD_LABELS),
        ignore_mismatched_sizes=True,
    )
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)

    print("🔀 Step 2: Merging LoRA weights into base model...")
    merged_model = peft_model.merge_and_unload(safe_merge=True)

    # transformers convention: id2label keys are ints in memory (they are
    # stringified automatically when the config is serialized to JSON).
    merged_model.config.id2label = dict(enumerate(CUAD_LABELS))
    merged_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}

    os.makedirs(MERGED_DIR, exist_ok=True)
    merged_model.save_pretrained(MERGED_DIR)
    tokenizer.save_pretrained(MERGED_DIR)
    print(f"   ✓ Merged model saved to {MERGED_DIR}")

    # Free the large torch models before the ONNX export allocates its own copy.
    del peft_model, base_model, merged_model
    import gc

    gc.collect()
    return tokenizer


def _export_to_onnx(tokenizer):
    """Step 3: export the merged model to ONNX; returns the ORT model."""
    print("\n📤 Step 3: Exporting to ONNX...")
    from optimum.onnxruntime import ORTModelForSequenceClassification

    ort_model = ORTModelForSequenceClassification.from_pretrained(
        MERGED_DIR, export=True
    )
    os.makedirs(ONNX_DIR, exist_ok=True)
    ort_model.save_pretrained(ONNX_DIR)
    tokenizer.save_pretrained(ONNX_DIR)
    print(f"   ✓ ONNX model saved to {ONNX_DIR}")
    return ort_model


def _quantize_int8(ort_model, tokenizer):
    """Step 4: apply dynamic INT8 quantization (no calibration data needed)."""
    print("\n⚡ Step 4: Applying dynamic INT8 quantization...")
    from optimum.onnxruntime import ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    # is_static=False => dynamic quantization, so no calibration dataset.
    qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
    quantizer = ORTQuantizer.from_pretrained(ort_model)
    os.makedirs(QUANT_DIR, exist_ok=True)
    quantizer.quantize(save_dir=QUANT_DIR, quantization_config=qconfig)

    # Quantized dir must be self-contained: tokenizer files + model config.
    tokenizer.save_pretrained(QUANT_DIR)
    shutil.copy2(os.path.join(ONNX_DIR, "config.json"), QUANT_DIR)
    print(f"   ✓ Quantized model saved to {QUANT_DIR}")


def _verify_quantized():
    """Step 5: smoke-test the quantized model on a few sample clauses.

    Returns (model, tokenizer) so the caller can push them to the Hub.
    """
    print("\n🧪 Step 5: Verifying quantized model...")
    import torch
    from transformers import AutoTokenizer
    from optimum.onnxruntime import ORTModelForSequenceClassification

    quant_model = ORTModelForSequenceClassification.from_pretrained(
        QUANT_DIR, file_name="model_quantized.onnx"
    )
    quant_tokenizer = AutoTokenizer.from_pretrained(QUANT_DIR)

    test_texts = [
        "The company may terminate your account at any time without notice.",
        "Either party shall indemnify and hold harmless the other party.",
        "This Agreement shall be governed by the laws of the State of Delaware.",
    ]
    inputs = quant_tokenizer(
        test_texts, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        outputs = quant_model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)

    for text, row in zip(test_texts, probs):
        top_prob, top_idx = torch.max(row, dim=0)
        idx = int(top_idx)
        # Guard against a head wider than the label list (ignore_mismatched_sizes).
        label = CUAD_LABELS[idx] if idx < len(CUAD_LABELS) else f"Class-{idx}"
        print(f"   Text: {text[:60]}...")
        print(f"   → {label} ({top_prob:.3f})")
    return quant_model, quant_tokenizer


def _push_to_hub(quant_model, quant_tokenizer):
    """Step 6: push model + tokenizer to the HF Hub (honors PUSH_TO_HUB)."""
    if PUSH_TO_HUB:
        print(f"\n🚀 Step 6: Pushing to {HUB_MODEL_ID}...")
        # NOTE(review): use_auth_token is deprecated in recent huggingface_hub;
        # migrate to token=True once the pinned transformers version supports it.
        quant_model.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        quant_tokenizer.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        print(f"   ✓ Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    else:
        print("\n⏭  Skipping Hub push (PUSH_TO_HUB=false)")


def _print_summary():
    """Report on-disk sizes of both exports and the achieved reduction."""
    onnx_size = os.path.getsize(os.path.join(ONNX_DIR, "model.onnx")) / 1e6
    quant_size = os.path.getsize(os.path.join(QUANT_DIR, "model_quantized.onnx")) / 1e6
    print(f"\n{'=' * 60}")
    print(f"   ONNX model size:      {onnx_size:.1f} MB")
    print(f"   Quantized model size: {quant_size:.1f} MB")
    print(f"   Size reduction:       {(1 - quant_size / onnx_size) * 100:.0f}%")
    print("   Expected speedup:     2-4x on CPU")
    print(f"{'=' * 60}")


def main():
    """Run the full pipeline: merge -> ONNX export -> INT8 -> verify -> push."""
    print("🛡️  ClauseGuard ONNX Export + INT8 Quantization")
    print("=" * 60)
    print(f"   Base model:   {BASE_MODEL}")
    print(f"   LoRA adapter: {ADAPTER_MODEL}")
    print(f"   Hub target:   {HUB_MODEL_ID}")
    print()

    tokenizer = _merge_lora()
    ort_model = _export_to_onnx(tokenizer)
    _quantize_int8(ort_model, tokenizer)
    quant_model, quant_tokenizer = _verify_quantized()
    _push_to_hub(quant_model, quant_tokenizer)
    _print_summary()

    print("\n✓ Export complete!")
    print(f"\nTo use in ClauseGuard, set ONNX_MODEL_PATH={QUANT_DIR}")
    print("or point to the Hub model: gaurv007/clauseguard-onnx-int8")


if __name__ == "__main__":
    main()