gaurv007 committed on
Commit
ad221bd
·
verified ·
1 Parent(s): 25234d2

v4.3 perf: Update ml/export_onnx_v2.py

Browse files
Files changed (1) hide show
  1. ml/export_onnx_v2.py +169 -0
ml/export_onnx_v2.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ClauseGuard β€” ONNX Export + INT8 Quantization Pipeline (v2)
3
+ ═══════════════════════════════════════════════════════════
4
+ PERF v4.3: Full pipeline to export the CUAD LoRA classifier to ONNX+INT8.
5
+
6
+ Steps:
7
+ 1. Load base Legal-BERT + LoRA adapter
8
+ 2. merge_and_unload() β†’ plain PreTrainedModel
9
+ 3. Export to ONNX via optimum
10
+ 4. Dynamic INT8 quantization (no calibration data needed)
11
+ 5. Push quantized model to HuggingFace Hub
12
+
13
+ Usage:
14
+ pip install "optimum[onnxruntime]" peft transformers torch
15
+ python export_onnx_v2.py
16
+
17
+ # Or with custom paths:
18
+ HUB_MODEL_ID=gaurv007/clauseguard-onnx-int8 python export_onnx_v2.py
19
+
20
+ Hardware: Any CPU (no GPU needed for export)
21
+ Time: ~2-5 minutes
22
+ """
23
+
24
+ import os
25
+ import sys
26
+ import shutil
27
+
28
+ # ── Configuration ──
29
+ BASE_MODEL = os.environ.get("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
30
+ ADAPTER_MODEL = os.environ.get("ADAPTER_MODEL", "Mokshith31/legalbert-contract-clause-classification")
31
+ HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-onnx-int8")
32
+ PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
33
+
34
+ MERGED_DIR = "./merged_legalbert"
35
+ ONNX_DIR = "./onnx_legalbert"
36
+ QUANT_DIR = "./onnx_legalbert_int8"
37
+
38
+ CUAD_LABELS = [
39
+ "Document Name", "Parties", "Agreement Date", "Effective Date",
40
+ "Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
41
+ "Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
42
+ "No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
43
+ "Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
44
+ "Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
45
+ "Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
46
+ "Joint IP Ownership", "License Grant", "Non-Transferable License",
47
+ "Affiliate License-Licensor", "Affiliate License-Licensee",
48
+ "Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
49
+ "Source Code Escrow", "Post-Termination Services", "Audit Rights",
50
+ "Uncapped Liability", "Cap on Liability", "Liquidated Damages",
51
+ "Warranty Duration", "Insurance", "Covenant Not to Sue",
52
+ "Third Party Beneficiary", "Other",
53
+ ]
54
+
55
+
56
def _load_and_merge():
    """Steps 1-2: load base Legal-BERT + the LoRA adapter, merge, save to MERGED_DIR.

    Returns the tokenizer (reused by the later steps). The merged full-precision
    model is written to disk and released from memory before ONNX export.
    """
    print("📦 Step 1: Loading base model + LoRA adapter...")
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        # Derive the head size from the label list instead of hard-coding 41,
        # so the two can never drift apart.
        BASE_MODEL, num_labels=len(CUAD_LABELS), ignore_mismatched_sizes=True
    )
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)

    print("🔀 Step 2: Merging LoRA weights into base model...")
    merged_model = peft_model.merge_and_unload(safe_merge=True)

    # Transformers convention: id2label keys are ints in memory; they are
    # stringified automatically when the config is serialized to JSON.
    merged_model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}
    merged_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}

    os.makedirs(MERGED_DIR, exist_ok=True)
    merged_model.save_pretrained(MERGED_DIR)
    tokenizer.save_pretrained(MERGED_DIR)
    print(f" ✅ Merged model saved to {MERGED_DIR}")

    # Free the PyTorch weights before the ONNX export allocates its own copy.
    del peft_model, base_model, merged_model
    import gc
    gc.collect()
    return tokenizer


def _export_to_onnx(tokenizer):
    """Step 3: export the merged checkpoint to ONNX; returns the ORT model."""
    print("\n📤 Step 3: Exporting to ONNX...")
    from optimum.onnxruntime import ORTModelForSequenceClassification

    ort_model = ORTModelForSequenceClassification.from_pretrained(
        MERGED_DIR, export=True
    )
    os.makedirs(ONNX_DIR, exist_ok=True)
    ort_model.save_pretrained(ONNX_DIR)
    tokenizer.save_pretrained(ONNX_DIR)
    print(f" ✅ ONNX model saved to {ONNX_DIR}")
    return ort_model


def _quantize_int8(ort_model, tokenizer):
    """Step 4: apply dynamic INT8 quantization (no calibration data needed)."""
    print("\n⚡ Step 4: Applying dynamic INT8 quantization...")
    from optimum.onnxruntime.configuration import AutoQuantizationConfig
    from optimum.onnxruntime import ORTQuantizer

    # Dynamic (is_static=False) quantization: weights stored as INT8,
    # activation scales computed at runtime — no calibration set required.
    qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
    quantizer = ORTQuantizer.from_pretrained(ort_model)
    os.makedirs(QUANT_DIR, exist_ok=True)
    quantizer.quantize(save_dir=QUANT_DIR, quantization_config=qconfig)

    # The quantizer writes only the .onnx file; ship the tokenizer and
    # config.json alongside so the directory is directly loadable.
    tokenizer.save_pretrained(QUANT_DIR)
    shutil.copy2(os.path.join(ONNX_DIR, "config.json"), QUANT_DIR)
    print(f" ✅ Quantized model saved to {QUANT_DIR}")


def _verify_quantized():
    """Step 5: smoke-test the quantized model on sample clauses.

    Returns (quant_model, quant_tokenizer) for the optional Hub push.
    """
    print("\n🧪 Step 5: Verifying quantized model...")
    from optimum.onnxruntime import ORTModelForSequenceClassification
    from transformers import AutoTokenizer
    import torch

    quant_model = ORTModelForSequenceClassification.from_pretrained(
        QUANT_DIR, file_name="model_quantized.onnx"
    )
    quant_tokenizer = AutoTokenizer.from_pretrained(QUANT_DIR)

    test_texts = [
        "The company may terminate your account at any time without notice.",
        "Either party shall indemnify and hold harmless the other party.",
        "This Agreement shall be governed by the laws of the State of Delaware.",
    ]
    inputs = quant_tokenizer(test_texts, return_tensors="pt", padding=True, truncation=True, max_length=512)

    with torch.no_grad():
        outputs = quant_model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)

    for i, text in enumerate(test_texts):
        top_prob, top_idx = torch.max(probs[i], dim=0)
        # Guard against a head larger than the label list (ignore_mismatched_sizes).
        label = CUAD_LABELS[int(top_idx)] if int(top_idx) < len(CUAD_LABELS) else f"Class-{int(top_idx)}"
        print(f" Text: {text[:60]}...")
        print(f" → {label} ({top_prob:.3f})")
    return quant_model, quant_tokenizer


def _push_to_hub(quant_model, quant_tokenizer):
    """Step 6: publish the quantized model + tokenizer to the Hub (if enabled)."""
    if PUSH_TO_HUB:
        print(f"\n🚀 Step 6: Pushing to {HUB_MODEL_ID}...")
        # NOTE(review): `use_auth_token` is deprecated in recent
        # transformers/huggingface_hub in favor of `token=True`; kept as-is
        # because optimum's ORTModel.push_to_hub signature varies by version.
        quant_model.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        quant_tokenizer.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        print(f" ✅ Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    else:
        print("\n⏭️ Skipping Hub push (PUSH_TO_HUB=false)")


def _print_summary():
    """Report on-disk sizes of the FP32 and INT8 ONNX models."""
    onnx_size = os.path.getsize(os.path.join(ONNX_DIR, "model.onnx")) / 1e6
    quant_size = os.path.getsize(os.path.join(QUANT_DIR, "model_quantized.onnx")) / 1e6
    print(f"\n{'='*60}")
    print(f" 📊 ONNX model size: {onnx_size:.1f} MB")
    print(f" 📊 Quantized model size: {quant_size:.1f} MB")
    print(f" 📊 Size reduction: {(1 - quant_size/onnx_size)*100:.0f}%")
    print(f" 🔥 Expected speedup: 2-4x on CPU")
    print(f"{'='*60}")
    print("\n✅ Export complete!")
    print(f"\nTo use in ClauseGuard, set ONNX_MODEL_PATH={QUANT_DIR}")
    print("or point to the Hub model: gaurv007/clauseguard-onnx-int8")


def main():
    """Run the full pipeline: merge LoRA → ONNX export → INT8 quantize → verify → push."""
    print("🛡️ ClauseGuard ONNX Export + INT8 Quantization")
    print("=" * 60)
    print(f" Base model: {BASE_MODEL}")
    print(f" LoRA adapter: {ADAPTER_MODEL}")
    print(f" Hub target: {HUB_MODEL_ID}")
    print()

    tokenizer = _load_and_merge()
    ort_model = _export_to_onnx(tokenizer)
    _quantize_int8(ort_model, tokenizer)
    quant_model, quant_tokenizer = _verify_quantized()
    _push_to_hub(quant_model, quant_tokenizer)
    _print_summary()


if __name__ == "__main__":
    main()