"""
ClauseGuard β€” ONNX Export + INT8 Quantization Pipeline (v2)
═══════════════════════════════════════════════════════════
PERF v4.3: Full pipeline to export the CUAD LoRA classifier to ONNX+INT8.

Steps:
  1. Load base Legal-BERT + LoRA adapter
  2. merge_and_unload() β†’ plain PreTrainedModel
  3. Export to ONNX via optimum
  4. Dynamic INT8 quantization (no calibration data needed)
  5. Verify the quantized model on sample clauses
  6. Push quantized model to HuggingFace Hub

Usage:
    pip install "optimum[onnxruntime]" peft transformers torch
    python export_onnx_v2.py

    # Or with custom paths:
    HUB_MODEL_ID=gaurv007/clauseguard-onnx-int8 python export_onnx_v2.py

Hardware: Any CPU (no GPU needed for export)
Time: ~2-5 minutes
"""

import os
import sys
import shutil

# ── Configuration ──
BASE_MODEL = os.environ.get("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
ADAPTER_MODEL = os.environ.get("ADAPTER_MODEL", "Mokshith31/legalbert-contract-clause-classification")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-onnx-int8")
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"

MERGED_DIR = "./merged_legalbert"
ONNX_DIR = "./onnx_legalbert"
QUANT_DIR = "./onnx_legalbert_int8"

CUAD_LABELS = [
    "Document Name", "Parties", "Agreement Date", "Effective Date",
    "Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
    "Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
    "No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
    "Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
    "Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
    "Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
    "Joint IP Ownership", "License Grant", "Non-Transferable License",
    "Affiliate License-Licensor", "Affiliate License-Licensee",
    "Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
    "Source Code Escrow", "Post-Termination Services", "Audit Rights",
    "Uncapped Liability", "Cap on Liability", "Liquidated Damages",
    "Warranty Duration", "Insurance", "Covenant Not to Sue",
    "Third Party Beneficiary", "Other",
]


def main():
    print("πŸ›‘οΈ  ClauseGuard ONNX Export + INT8 Quantization")
    print("=" * 60)
    print(f"   Base model:   {BASE_MODEL}")
    print(f"   LoRA adapter: {ADAPTER_MODEL}")
    print(f"   Hub target:   {HUB_MODEL_ID}")
    print()

    # ── Step 1: Load and merge LoRA ──
    print("πŸ“¦ Step 1: Loading base model + LoRA adapter...")
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from peft import PeftModel

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=len(CUAD_LABELS), ignore_mismatched_sizes=True
    )
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)

    print("πŸ”€ Step 2: Merging LoRA weights into base model...")
    merged_model = peft_model.merge_and_unload(safe_merge=True)

    # Set label mapping (int keys are the transformers convention; they are
    # serialized to strings in config.json automatically)
    merged_model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}
    merged_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}

    os.makedirs(MERGED_DIR, exist_ok=True)
    merged_model.save_pretrained(MERGED_DIR)
    tokenizer.save_pretrained(MERGED_DIR)
    print(f"   βœ… Merged model saved to {MERGED_DIR}")

    # Free memory
    del peft_model, base_model, merged_model
    import gc
    gc.collect()

    # ── Step 3: Export to ONNX ──
    print("\nπŸ“€ Step 3: Exporting to ONNX...")
    from optimum.onnxruntime import ORTModelForSequenceClassification

    ort_model = ORTModelForSequenceClassification.from_pretrained(
        MERGED_DIR, export=True
    )
    os.makedirs(ONNX_DIR, exist_ok=True)
    ort_model.save_pretrained(ONNX_DIR)
    tokenizer.save_pretrained(ONNX_DIR)
    print(f"   βœ… ONNX model saved to {ONNX_DIR}")

    # ── Step 4: Dynamic INT8 Quantization ──
    print("\n⚑ Step 4: Applying dynamic INT8 quantization...")
    from optimum.onnxruntime.configuration import AutoQuantizationConfig
    from optimum.onnxruntime import ORTQuantizer

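    # Note: avx512_vnni assumes a CPU with AVX-512 VNNI support; on other
    # hardware, AutoQuantizationConfig.avx2(...) or .arm64(...) are drop-in
    # alternatives with the same is_static/per_channel arguments.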
    qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
    quantizer = ORTQuantizer.from_pretrained(ort_model)
    os.makedirs(QUANT_DIR, exist_ok=True)
    quantizer.quantize(save_dir=QUANT_DIR, quantization_config=qconfig)

    # Copy tokenizer files to quantized dir
    tokenizer.save_pretrained(QUANT_DIR)
    # Copy config.json too
    shutil.copy2(os.path.join(ONNX_DIR, "config.json"), QUANT_DIR)
    print(f"   βœ… Quantized model saved to {QUANT_DIR}")

    # ── Step 5: Verify ──
    print("\nπŸ§ͺ Step 5: Verifying quantized model...")
    quant_model = ORTModelForSequenceClassification.from_pretrained(
        QUANT_DIR, file_name="model_quantized.onnx"
    )
    quant_tokenizer = AutoTokenizer.from_pretrained(QUANT_DIR)

    test_texts = [
        "The company may terminate your account at any time without notice.",
        "Either party shall indemnify and hold harmless the other party.",
        "This Agreement shall be governed by the laws of the State of Delaware.",
    ]
    inputs = quant_tokenizer(test_texts, return_tensors="pt", padding=True, truncation=True, max_length=512)

    import torch
    # No torch.no_grad() needed: inference runs through ONNX Runtime, not the
    # PyTorch autograd engine; the softmax below is plain tensor math.
    outputs = quant_model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)

    for i, text in enumerate(test_texts):
        top_prob, top_idx = torch.max(probs[i], dim=0)
        label = CUAD_LABELS[int(top_idx)] if int(top_idx) < len(CUAD_LABELS) else f"Class-{int(top_idx)}"
        print(f"   Text: {text[:60]}...")
        print(f"   β†’ {label} ({top_prob:.3f})")
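
    # Optional rough latency probe (a sanity check, not a benchmark): run the
    # same batch through the FP32 ONNX model and the INT8 one. Absolute numbers
    # vary widely with CPU model and thread settings.
    import time
    fp32_model = ORTModelForSequenceClassification.from_pretrained(ONNX_DIR)
    for name, m in [("FP32", fp32_model), ("INT8", quant_model)]:
        start = time.perf_counter()
        for _ in range(10):
            m(**inputs)
        per_batch_ms = (time.perf_counter() - start) / 10 * 1000
        print(f"   ⏱️  {name} ONNX: {per_batch_ms:.1f} ms per batch of {len(test_texts)}")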

    # ── Step 6: Push to Hub ──
    if PUSH_TO_HUB:
        print(f"\nπŸš€ Step 6: Pushing to {HUB_MODEL_ID}...")
        quant_model.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        quant_tokenizer.push_to_hub(HUB_MODEL_ID, use_auth_token=True)
        print(f"   βœ… Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    else:
        print("\n⏭️  Skipping Hub push (PUSH_TO_HUB=false)")

    # ── Summary ──
    onnx_size = os.path.getsize(os.path.join(ONNX_DIR, "model.onnx")) / 1e6
    quant_size = os.path.getsize(os.path.join(QUANT_DIR, "model_quantized.onnx")) / 1e6
    print(f"\n{'='*60}")
    print(f"   πŸ“Š ONNX model size:      {onnx_size:.1f} MB")
    print(f"   πŸ“Š Quantized model size:  {quant_size:.1f} MB")
    print(f"   πŸ“Š Size reduction:        {(1 - quant_size/onnx_size)*100:.0f}%")
    print(f"   πŸ”₯ Expected speedup:      2-4x on CPU")
    print(f"{'='*60}")
    print("\nβœ… Export complete!")
    print(f"\nTo use in ClauseGuard, set ONNX_MODEL_PATH={QUANT_DIR}")
    print(f"or point to the Hub model: {HUB_MODEL_ID}")


if __name__ == "__main__":
    main()
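

# Downstream usage sketch (illustrative; mirrors the verification step above).
# Loading the quantized model from the Hub with optimum plus a transformers
# pipeline:
#
#     from transformers import AutoTokenizer, pipeline
#     from optimum.onnxruntime import ORTModelForSequenceClassification
#
#     model = ORTModelForSequenceClassification.from_pretrained(
#         "gaurv007/clauseguard-onnx-int8", file_name="model_quantized.onnx"
#     )
#     tokenizer = AutoTokenizer.from_pretrained("gaurv007/clauseguard-onnx-int8")
#     clf = pipeline("text-classification", model=model, tokenizer=tokenizer)
#     print(clf("This Agreement shall be governed by the laws of Delaware."))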