gaurv007 committed
Commit a4cb2c1 · verified · 1 Parent(s): 924af4e

Add ClauseGuard v4 Colab notebook: DeBERTa-v3-large 2-stage training (LEDGAR→CUAD) with ASL

Files changed (1)
  1. ml/ClauseGuard_DeBERTa_Training.ipynb +1042 -0
ml/ClauseGuard_DeBERTa_Training.ipynb ADDED
@@ -0,0 +1,1042 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "A100",
+ "machine_shape": "hm"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 🛡️ ClauseGuard v4 — DeBERTa-v3-large 2-Stage Training\n",
+ "\n",
+ "**Goal:** Train a production-grade contract clause classifier that replaces the current Legal-BERT-base model (~50% macro-F1 → target 80-87% macro-F1).\n",
+ "\n",
+ "## Architecture\n",
+ "| Setting | Value | Source |\n",
+ "|---------|-------|--------|\n",
+ "| Base model | `microsoft/deberta-v3-large` (435M params) | LexGLUE: outperforms Legal-BERT by 7-10pp |\n",
+ "| Max length | 512 tokens | MAUD paper: covers 72.4% of clauses without truncation |\n",
+ "| Loss function | Asymmetric Loss (γ-=4, clip=0.05) | ASL paper (2009.14119): +3-8pp on rare classes |\n",
+ "| Training | Full fine-tuning (no LoRA) | Full FT wins for encoder classification |\n",
+ "\n",
+ "## 2-Stage Training Pipeline\n",
+ "1. **Stage 1 — LEDGAR** (60K legal provisions, 100 classes): Teaches \"what types of contract clauses exist\"\n",
+ "2. **Stage 2 — CUAD** (41 clause classes): The target task, trained with Asymmetric Loss for class imbalance\n",
+ "\n",
+ "**Runtime:** ~4-6 hours on A100 GPU\n",
+ "\n",
+ "**Before running:**\n",
+ "1. `Runtime` → `Change runtime type` → **A100 GPU** (High-RAM if available)\n",
+ "2. `Runtime` → `Run all`\n",
+ "3. Paste your HuggingFace token when prompted"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 1: Install Dependencies"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install -q transformers datasets scikit-learn accelerate huggingface_hub torch\n",
+ "!pip install -q trackio # optional: experiment tracking"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 2: Login to HuggingFace Hub"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from huggingface_hub import login\n",
+ "login()"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 3: Configuration"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "# CONFIGURATION — Edit these values\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "\n",
+ "BASE_MODEL = \"microsoft/deberta-v3-large\" # 435M params, MIT license\n",
+ "MAX_LENGTH = 512 # covers 72.4% of clauses\n",
+ "HUB_MODEL_ID = \"gaurv007/clauseguard-deberta-v3-large\" # ← your model repo\n",
+ "\n",
+ "# Stage 1: LEDGAR config\n",
+ "STAGE1_EPOCHS = 5 # LEDGAR is large, converges fast\n",
+ "STAGE1_LR = 2e-5\n",
+ "STAGE1_BATCH = 8\n",
+ "STAGE1_GRAD_ACCUM = 4 # effective batch = 32\n",
+ "\n",
+ "# Stage 2: CUAD config\n",
+ "STAGE2_EPOCHS = 20\n",
+ "STAGE2_LR = 1e-5 # lower LR for fine-tuning pretrained model\n",
+ "STAGE2_BATCH = 8\n",
+ "STAGE2_GRAD_ACCUM = 4 # effective batch = 32\n",
+ "EARLY_STOPPING_PATIENCE = 3\n",
+ "\n",
+ "# ASL hyperparameters (from arxiv 2009.14119)\n",
+ "ASL_GAMMA_POS = 0\n",
+ "ASL_GAMMA_NEG = 4\n",
+ "ASL_CLIP = 0.05\n",
+ "\n",
+ "# Weight decay (DeBERTa default)\n",
+ "WEIGHT_DECAY = 0.06\n",
+ "WARMUP_RATIO = 0.1\n",
+ "\n",
+ "SEED = 42\n",
+ "\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "\n",
+ "# CUAD 41 label names (must match class_id 0-40 in CUAD dataset)\n",
+ "CUAD_LABELS = [\n",
+ " \"Document Name\", # 0\n",
+ " \"Parties\", # 1\n",
+ " \"Agreement Date\", # 2\n",
+ " \"Effective Date\", # 3\n",
+ " \"Expiration Date\", # 4\n",
+ " \"Renewal Term\", # 5\n",
+ " \"Notice Period to Terminate Renewal\", # 6\n",
+ " \"Governing Law\", # 7\n",
+ " \"Most Favored Nation\", # 8\n",
+ " \"Non-Compete\", # 9\n",
+ " \"Exclusivity\", # 10\n",
+ " \"No-Solicit of Customers\", # 11\n",
+ " \"No-Solicit of Employees\", # 12\n",
+ " \"Non-Disparagement\", # 13\n",
+ " \"Termination for Convenience\", # 14\n",
+ " \"ROFR/ROFO/ROFN\", # 15\n",
+ " \"Change of Control\", # 16\n",
+ " \"Anti-Assignment\", # 17\n",
+ " \"Revenue/Profit Sharing\", # 18\n",
+ " \"Price Restriction\", # 19\n",
+ " \"Minimum Commitment\", # 20\n",
+ " \"Volume Restriction\", # 21\n",
+ " \"IP Ownership Assignment\", # 22\n",
+ " \"Joint IP Ownership\", # 23\n",
+ " \"License Grant\", # 24\n",
+ " \"Non-Transferable License\", # 25\n",
+ " \"Affiliate License-Licensor\", # 26\n",
+ " \"Affiliate License-Licensee\", # 27\n",
+ " \"Unlimited/All-You-Can-Eat License\", # 28\n",
+ " \"Irrevocable or Perpetual License\", # 29\n",
+ " \"Source Code Escrow\", # 30\n",
+ " \"Post-Termination Services\", # 31\n",
+ " \"Audit Rights\", # 32\n",
+ " \"Uncapped Liability\", # 33\n",
+ " \"Cap on Liability\", # 34\n",
+ " \"Liquidated Damages\", # 35\n",
+ " \"Warranty Duration\", # 36\n",
+ " \"Insurance\", # 37\n",
+ " \"Covenant Not to Sue\", # 38\n",
+ " \"Third Party Beneficiary\", # 39\n",
+ " \"Other\", # 40\n",
+ "]\n",
+ "\n",
+ "NUM_CUAD_LABELS = len(CUAD_LABELS) # 41\n",
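+ "\n",
+ "# Sanity check (illustrative addition): the label list must line up with\n",
+ "# class_id 0-40 in the CUAD classification dataset.\n",
+ "assert NUM_CUAD_LABELS == 41, f\"Expected 41 CUAD labels, got {NUM_CUAD_LABELS}\"\n",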
+ "\n",
+ "print(f\"🛡️ ClauseGuard v4 Training Configuration\")\n",
+ "print(f\" Base model: {BASE_MODEL}\")\n",
+ "print(f\" Max length: {MAX_LENGTH}\")\n",
+ "print(f\" Hub model: {HUB_MODEL_ID}\")\n",
+ "print(f\" GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
+ "print(f\" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\" if torch.cuda.is_available() else \"\")\n",
+ "print(f\" CUAD classes: {NUM_CUAD_LABELS}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 4: Load Datasets"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from datasets import load_dataset, Dataset\n",
+ "import pandas as pd\n",
+ "from collections import Counter\n",
+ "\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "# Stage 1: LEDGAR (100 classes, single-label)\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "print(\"📚 Loading LEDGAR dataset...\")\n",
+ "ledgar = load_dataset(\"coastalcph/lex_glue\", \"ledgar\")\n",
+ "print(f\" Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,} | Test: {len(ledgar['test']):,}\")\n",
+ "num_ledgar_labels = ledgar['train'].features['label'].num_classes\n",
+ "print(f\" Classes: {num_ledgar_labels}\")\n",
+ "\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "# Stage 2: CUAD (41 classes — reformulated for classification)\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "print(\"\\n📚 Loading CUAD classification dataset...\")\n",
+ "cuad_raw = load_dataset(\"dvgodoy/CUAD_v1_Contract_Understanding_clause_classification\", split=\"train\")\n",
+ "print(f\" Total rows: {len(cuad_raw):,}\")\n",
+ "\n",
+ "# Analyze class distribution\n",
+ "class_counts = Counter(cuad_raw['class_id'])\n",
+ "print(f\" Unique classes: {len(class_counts)}\")\n",
+ "print(\"\\n Class distribution:\")\n",
+ "for cid in sorted(class_counts.keys()):\n",
+ " label_name = CUAD_LABELS[cid] if cid < len(CUAD_LABELS) else f\"Unknown-{cid}\"\n",
+ " count = class_counts[cid]\n",
+ " bar = '█' * min(50, count // 10)\n",
+ " print(f\" {cid:2d} {label_name:40s} {count:5d} {bar}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 5: Prepare CUAD Train/Val/Test Splits"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "# CUAD only has a train split — create val/test by splitting on file_name\n",
+ "# (so no data leakage between contracts)\n",
+ "cuad_df = cuad_raw.to_pandas()\n",
+ "\n",
+ "# Get unique file names\n",
+ "unique_files = cuad_df['file_name'].unique()\n",
+ "print(f\"Unique contracts: {len(unique_files)}\")\n",
+ "\n",
+ "# Split files 80/10/10\n",
+ "train_files, test_files = train_test_split(unique_files, test_size=0.2, random_state=SEED)\n",
+ "val_files, test_files = train_test_split(test_files, test_size=0.5, random_state=SEED)\n",
+ "\n",
+ "cuad_train_df = cuad_df[cuad_df['file_name'].isin(train_files)]\n",
+ "cuad_val_df = cuad_df[cuad_df['file_name'].isin(val_files)]\n",
+ "cuad_test_df = cuad_df[cuad_df['file_name'].isin(test_files)]\n",
+ "\n",
+ "print(f\"CUAD splits — Train: {len(cuad_train_df)} | Val: {len(cuad_val_df)} | Test: {len(cuad_test_df)}\")\n",
+ "print(f\"Train contracts: {len(train_files)} | Val contracts: {len(val_files)} | Test contracts: {len(test_files)}\")\n",
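+ "\n",
+ "# Leakage guard (illustrative addition): the grouped split is the point of\n",
+ "# splitting by file_name, so verify the contract sets are actually disjoint.\n",
+ "assert set(train_files).isdisjoint(val_files) and set(train_files).isdisjoint(test_files)\n",
+ "assert set(val_files).isdisjoint(test_files)\n",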
+ "\n",
+ "# Convert to HF Dataset\n",
+ "cuad_train = Dataset.from_pandas(cuad_train_df.reset_index(drop=True))\n",
+ "cuad_val = Dataset.from_pandas(cuad_val_df.reset_index(drop=True))\n",
+ "cuad_test = Dataset.from_pandas(cuad_test_df.reset_index(drop=True))\n",
+ "\n",
+ "# Verify class distribution in each split\n",
+ "for name, ds in [(\"Train\", cuad_train), (\"Val\", cuad_val), (\"Test\", cuad_test)]:\n",
+ " counts = Counter(ds['class_id'])\n",
+ " empty_classes = [i for i in range(NUM_CUAD_LABELS) if counts.get(i, 0) == 0]\n",
+ " print(f\" {name}: {len(ds)} rows, {len(counts)} classes present, {len(empty_classes)} classes missing: {empty_classes[:5]}...\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 6: Tokenizer & Preprocessing"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "print(f\"Loading tokenizer: {BASE_MODEL}\")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
+ "\n",
+ "# ── LEDGAR preprocessing (single-label) ──\n",
+ "def preprocess_ledgar(examples):\n",
+ " tokenized = tokenizer(\n",
+ " examples[\"text\"],\n",
+ " truncation=True,\n",
+ " max_length=MAX_LENGTH,\n",
+ " padding=False,\n",
+ " )\n",
+ " tokenized[\"labels\"] = examples[\"label\"] # int label for CrossEntropy\n",
+ " return tokenized\n",
+ "\n",
+ "# ── CUAD preprocessing (single-label per clause, 41 classes) ──\n",
+ "def preprocess_cuad(examples):\n",
+ " tokenized = tokenizer(\n",
+ " examples[\"clause\"],\n",
+ " truncation=True,\n",
+ " max_length=MAX_LENGTH,\n",
+ " padding=False,\n",
+ " )\n",
+ " tokenized[\"labels\"] = examples[\"class_id\"] # int label for CrossEntropy + ASL\n",
+ " return tokenized\n",
+ "\n",
+ "print(\"Tokenizing LEDGAR...\")\n",
+ "ledgar_tokenized = ledgar.map(\n",
+ " preprocess_ledgar, batched=True,\n",
+ " remove_columns=ledgar[\"train\"].column_names,\n",
+ " desc=\"Tokenizing LEDGAR\"\n",
+ ")\n",
+ "\n",
+ "print(\"Tokenizing CUAD...\")\n",
+ "cuad_train_tok = cuad_train.map(\n",
+ " preprocess_cuad, batched=True,\n",
+ " remove_columns=cuad_train.column_names,\n",
+ " desc=\"Tokenizing CUAD train\"\n",
+ ")\n",
+ "cuad_val_tok = cuad_val.map(\n",
+ " preprocess_cuad, batched=True,\n",
+ " remove_columns=cuad_val.column_names,\n",
+ " desc=\"Tokenizing CUAD val\"\n",
+ ")\n",
+ "cuad_test_tok = cuad_test.map(\n",
+ " preprocess_cuad, batched=True,\n",
+ " remove_columns=cuad_test.column_names,\n",
+ " desc=\"Tokenizing CUAD test\"\n",
+ ")\n",
+ "\n",
+ "# Check token lengths\n",
+ "train_lengths = [len(x) for x in cuad_train_tok['input_ids']]\n",
+ "print(f\"\\n📊 CUAD token length stats:\")\n",
+ "print(f\" Mean: {np.mean(train_lengths):.0f} | Median: {np.median(train_lengths):.0f}\")\n",
+ "print(f\" 95th pct: {np.percentile(train_lengths, 95):.0f} | Max: {max(train_lengths)}\")\n",
+ "print(f\" At max length (≥512, truncated): {sum(1 for l in train_lengths if l >= MAX_LENGTH)} ({sum(1 for l in train_lengths if l >= MAX_LENGTH)/len(train_lengths)*100:.1f}%)\")\n",
+ "print(\"✅ Tokenization complete!\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 7: Asymmetric Loss Function\n",
+ "\n",
+ "From [Asymmetric Loss For Multi-Label Classification](https://arxiv.org/abs/2009.14119) (ICCV 2021).\n",
+ "\n",
+ "Key idea: Down-weight easy negatives more aggressively than positives. Critical for CUAD, where most labels are negative for any given clause."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "\n",
+ "class AsymmetricLoss(nn.Module):\n",
+ " \"\"\"\n",
+ " Asymmetric Loss from arxiv:2009.14119.\n",
+ " \n",
+ " For multi-class (single-label) classification with class imbalance:\n",
+ " We use a multi-class variant — focal-style re-weighting of\n",
+ " cross-entropy, with gamma_neg as the focusing parameter.\n",
+ " \n",
+ " For multi-label (multi-hot) classification:\n",
+ " L+ = (1-p)^γ+ * log(p)\n",
+ " L- = (pm)^γ- * log(1-pm), pm = max(p - m, 0)\n",
+ " \"\"\"\n",
+ " def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,\n",
+ " num_classes=None, class_weights=None, mode=\"multi_class\"):\n",
+ " super().__init__()\n",
+ " self.gamma_pos = gamma_pos\n",
+ " self.gamma_neg = gamma_neg\n",
+ " self.clip = clip\n",
+ " self.eps = eps\n",
+ " self.mode = mode\n",
+ " \n",
+ " # Optional class weights for severe imbalance\n",
+ " if class_weights is not None:\n",
+ " self.register_buffer('class_weights', torch.tensor(class_weights, dtype=torch.float32))\n",
+ " else:\n",
+ " self.class_weights = None\n",
+ "\n",
+ " def forward(self, logits, targets):\n",
+ " if self.mode == \"multi_label\":\n",
+ " return self._multi_label_loss(logits, targets)\n",
+ " else:\n",
+ " return self._multi_class_loss(logits, targets)\n",
+ " \n",
+ " def _multi_class_loss(self, logits, targets):\n",
+ " \"\"\"Focal-style cross-entropy for single-label classification.\"\"\"\n",
+ " # Standard cross-entropy with class weights\n",
+ " if self.class_weights is not None:\n",
+ " ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')\n",
+ " else:\n",
+ " ce_loss = F.cross_entropy(logits, targets, reduction='none')\n",
+ " \n",
+ " # Apply focal modulation\n",
+ " probs = F.softmax(logits, dim=-1)\n",
+ " # Get probability of the correct class\n",
+ " p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)\n",
+ " \n",
+ " # Focal weight: (1 - p_t)^gamma\n",
+ " # gamma_neg is the single focusing parameter here: hard examples (low p_t)\n",
+ " # keep most of their loss, easy ones are down-weighted\n",
+ " focal_weight = (1 - p_t) ** self.gamma_neg\n",
+ " \n",
+ " loss = focal_weight * ce_loss\n",
+ " return loss.mean()\n",
+ "\n",
+ " def _multi_label_loss(self, logits, targets):\n",
+ " \"\"\"Full ASL for multi-label classification.\"\"\"\n",
+ " p = torch.sigmoid(logits)\n",
+ " \n",
+ " if self.clip is not None and self.clip > 0:\n",
+ " p_m = torch.clamp(p - self.clip, min=0)\n",
+ " else:\n",
+ " p_m = p\n",
+ " \n",
+ " loss_pos = targets * (1 - p) ** self.gamma_pos * torch.log(p + self.eps)\n",
+ " loss_neg = (1 - targets) * p_m ** self.gamma_neg * torch.log(1 - p_m + self.eps)\n",
+ " \n",
+ " loss = -(loss_pos + loss_neg)\n",
+ " return loss.mean()\n",
+ "\n",
+ "\n",
+ "print(\"✅ AsymmetricLoss defined\")\n",
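+ "\n",
+ "# Illustrative check with toy logits (values assumed purely for demonstration):\n",
+ "# an easy, confident example should contribute far less loss than a hard one.\n",
+ "_easy = AsymmetricLoss(gamma_neg=ASL_GAMMA_NEG)(torch.tensor([[4.0, -2.0]]), torch.tensor([0]))\n",
+ "_hard = AsymmetricLoss(gamma_neg=ASL_GAMMA_NEG)(torch.tensor([[0.5, 0.4]]), torch.tensor([0]))\n",
+ "print(f\" toy ASL loss: easy={_easy:.3e}, hard={_hard:.3e}\")\n",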
+ "print(f\" γ+ = {ASL_GAMMA_POS}, γ- = {ASL_GAMMA_NEG}, clip = {ASL_CLIP}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 8: Custom Trainer with ASL"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import Trainer\n",
+ "\n",
+ "\n",
+ "class ASLTrainer(Trainer):\n",
+ " \"\"\"Custom Trainer that uses Asymmetric Loss instead of standard CrossEntropy.\"\"\"\n",
+ " \n",
+ " def __init__(self, *args, asl_loss_fn=None, **kwargs):\n",
+ " super().__init__(*args, **kwargs)\n",
+ " self.asl = asl_loss_fn\n",
+ "\n",
+ " def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
+ " labels = inputs.pop(\"labels\")\n",
+ " outputs = model(**inputs)\n",
+ " logits = outputs.logits\n",
+ " \n",
+ " if self.asl is not None:\n",
+ " loss = self.asl(logits, labels)\n",
+ " else:\n",
+ " # Fallback to standard cross-entropy\n",
+ " loss = F.cross_entropy(logits, labels)\n",
+ " \n",
+ " return (loss, outputs) if return_outputs else loss\n",
+ "\n",
+ "\n",
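+ "# Usage note (illustrative, mirrors the Stage 2 setup below): construct it like\n",
+ "# a regular Trainer plus the extra keyword, e.g.\n",
+ "# ASLTrainer(model=model, args=args, asl_loss_fn=AsymmetricLoss(gamma_neg=4), ...)\n",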
+ "print(\"✅ ASLTrainer defined\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 9: Metrics"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from sklearn.metrics import f1_score, precision_score, recall_score, classification_report\n",
+ "\n",
+ "\n",
+ "def compute_metrics_single_label(eval_pred):\n",
+ " \"\"\"Metrics for single-label classification (LEDGAR & CUAD).\"\"\"\n",
+ " logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
+ " preds = np.argmax(logits, axis=-1)\n",
+ " \n",
+ " micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n",
+ " macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
+ " weighted_f1 = f1_score(labels, preds, average=\"weighted\", zero_division=0)\n",
+ " accuracy = (preds == labels).mean()\n",
+ " \n",
+ " return {\n",
+ " \"accuracy\": accuracy,\n",
+ " \"micro_f1\": micro_f1,\n",
+ " \"macro_f1\": macro_f1,\n",
+ " \"weighted_f1\": weighted_f1,\n",
+ " }\n",
+ "\n",
+ "\n",
+ "def compute_metrics_cuad_detailed(eval_pred):\n",
+ " \"\"\"Detailed metrics for CUAD — includes per-class F1.\"\"\"\n",
+ " logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
+ " preds = np.argmax(logits, axis=-1)\n",
+ " \n",
+ " micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n",
+ " macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
+ " weighted_f1 = f1_score(labels, preds, average=\"weighted\", zero_division=0)\n",
+ " accuracy = (preds == labels).mean()\n",
+ " \n",
+ " # Per-class F1\n",
+ " per_class_f1 = f1_score(labels, preds, average=None, zero_division=0)\n",
+ " class_metrics = {}\n",
+ " for i, f1_val in enumerate(per_class_f1):\n",
+ " if i < len(CUAD_LABELS):\n",
+ " # Truncate label name for cleaner logging\n",
+ " safe_name = CUAD_LABELS[i][:20].replace(\" \", \"_\").replace(\"/\", \"_\")\n",
+ " class_metrics[f\"f1_{safe_name}\"] = float(f1_val)\n",
+ " \n",
+ " return {\n",
+ " \"accuracy\": accuracy,\n",
+ " \"micro_f1\": micro_f1,\n",
+ " \"macro_f1\": macro_f1,\n",
+ " \"weighted_f1\": weighted_f1,\n",
+ " **class_metrics,\n",
+ " }\n",
+ "\n",
+ "\n",
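+ "# Smoke test on dummy predictions (illustrative only; 4 examples, 3 classes assumed):\n",
+ "from transformers import EvalPrediction\n",
+ "_dummy = EvalPrediction(predictions=np.eye(3)[[0, 1, 2, 0]], label_ids=np.array([0, 1, 2, 1]))\n",
+ "print(\"Dummy metrics:\", compute_metrics_single_label(_dummy))\n",
+ "\n",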
+ "print(\"✅ Metrics functions defined\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "---\n",
+ "# 🏋️ STAGE 1: Pre-fine-tune on LEDGAR\n",
+ "\n",
+ "**Goal:** Teach DeBERTa-v3-large what types of contract clauses exist (100 classes, ~60K examples).\n",
+ "\n",
+ "This stage uses standard cross-entropy loss since LEDGAR is well-balanced.\n",
+ "\n",
+ "**Expected:** ~85-90% micro-F1 after 3-5 epochs (~1-2 hours on A100)"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import (\n",
+ " AutoConfig,\n",
+ " AutoModelForSequenceClassification,\n",
+ " TrainingArguments,\n",
+ " DataCollatorWithPadding,\n",
+ " EarlyStoppingCallback,\n",
+ ")\n",
+ "\n",
+ "print(f\"🏋️ STAGE 1: Pre-fine-tune on LEDGAR ({num_ledgar_labels} classes)\")\n",
+ "print(f\" Loading {BASE_MODEL}...\")\n",
+ "\n",
+ "# Load model for Stage 1 (100 classes, single-label)\n",
+ "stage1_model = AutoModelForSequenceClassification.from_pretrained(\n",
+ " BASE_MODEL,\n",
+ " num_labels=num_ledgar_labels,\n",
+ " problem_type=\"single_label_classification\",\n",
+ " ignore_mismatched_sizes=True,\n",
+ ")\n",
+ "\n",
+ "total_params = sum(p.numel() for p in stage1_model.parameters())\n",
+ "trainable_params = sum(p.numel() for p in stage1_model.parameters() if p.requires_grad)\n",
+ "print(f\" Total parameters: {total_params:,}\")\n",
+ "print(f\" Trainable parameters: {trainable_params:,}\")\n",
+ "\n",
+ "stage1_args = TrainingArguments(\n",
+ " output_dir=\"./stage1_ledgar\",\n",
+ " num_train_epochs=STAGE1_EPOCHS,\n",
+ " per_device_train_batch_size=STAGE1_BATCH,\n",
+ " per_device_eval_batch_size=16,\n",
+ " gradient_accumulation_steps=STAGE1_GRAD_ACCUM,\n",
+ " learning_rate=STAGE1_LR,\n",
+ " weight_decay=WEIGHT_DECAY,\n",
+ " warmup_ratio=WARMUP_RATIO,\n",
+ " lr_scheduler_type=\"cosine\",\n",
+ " eval_strategy=\"epoch\",\n",
+ " save_strategy=\"epoch\",\n",
+ " save_total_limit=2,\n",
+ " load_best_model_at_end=True,\n",
+ " metric_for_best_model=\"macro_f1\",\n",
+ " greater_is_better=True,\n",
+ " bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,\n",
+ " fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,\n",
+ " logging_strategy=\"steps\",\n",
+ " logging_steps=50,\n",
+ " logging_first_step=True,\n",
+ " disable_tqdm=False, # Keep progress bar in Colab\n",
+ " report_to=\"none\",\n",
+ " dataloader_num_workers=2,\n",
+ " seed=SEED,\n",
+ " gradient_checkpointing=True, # Save VRAM on A100\n",
+ ")\n",
+ "\n",
+ "stage1_trainer = Trainer(\n",
+ " model=stage1_model,\n",
+ " args=stage1_args,\n",
+ " train_dataset=ledgar_tokenized[\"train\"],\n",
+ " eval_dataset=ledgar_tokenized[\"validation\"],\n",
+ " processing_class=tokenizer,\n",
+ " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
+ " compute_metrics=compute_metrics_single_label,\n",
+ " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],\n",
+ ")\n",
+ "\n",
+ "print(\"\\n🚀 Starting Stage 1 training...\")\n",
+ "stage1_result = stage1_trainer.train()\n",
+ "print(f\"\\n✅ Stage 1 complete! Loss: {stage1_result.training_loss:.4f}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Evaluate Stage 1 on LEDGAR test set\n",
+ "print(\"📊 Stage 1 — LEDGAR Test Evaluation\")\n",
+ "stage1_test = stage1_trainer.evaluate(ledgar_tokenized[\"test\"])\n",
+ "print(f\" Accuracy: {stage1_test['eval_accuracy']:.4f}\")\n",
+ "print(f\" Micro-F1: {stage1_test['eval_micro_f1']:.4f}\")\n",
+ "print(f\" Macro-F1: {stage1_test['eval_macro_f1']:.4f}\")\n",
+ "print(f\" Weighted-F1: {stage1_test['eval_weighted_f1']:.4f}\")\n",
+ "\n",
+ "# Save Stage 1 checkpoint\n",
+ "STAGE1_CHECKPOINT = \"./stage1_ledgar_best\"\n",
+ "stage1_trainer.save_model(STAGE1_CHECKPOINT)\n",
+ "tokenizer.save_pretrained(STAGE1_CHECKPOINT)\n",
+ "print(f\"\\n💾 Stage 1 checkpoint saved to {STAGE1_CHECKPOINT}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "---\n",
+ "# 🏋️ STAGE 2: Fine-tune on CUAD 41-class with Asymmetric Loss\n",
+ "\n",
+ "**Goal:** Learn the 41 CUAD contract clause types from the Stage 1 backbone.\n",
+ "\n",
+ "Key improvements over current ClauseGuard:\n",
+ "- DeBERTa-v3-large backbone pre-trained on LEDGAR (Stage 1)\n",
+ "- 512 tokens (vs 256) — captures full clause content\n",
+ "- Asymmetric Loss for class imbalance\n",
+ "- Full fine-tuning (no LoRA bottleneck)\n",
+ "\n",
+ "**Expected:** 75-87% macro-F1 after 10-20 epochs (~2-4 hours on A100)"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Free Stage 1 model memory before loading Stage 2\n",
+ "del stage1_model, stage1_trainer\n",
+ "torch.cuda.empty_cache()\n",
+ "import gc; gc.collect()\n",
+ "\n",
+ "print(f\"🏋️ STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL\")\n",
+ "\n",
+ "# Load Stage 1 checkpoint with new head (100 → 41 classes)\n",
+ "stage2_model = AutoModelForSequenceClassification.from_pretrained(\n",
+ " STAGE1_CHECKPOINT,\n",
+ " num_labels=NUM_CUAD_LABELS,\n",
+ " ignore_mismatched_sizes=True, # classifier head: 100 → 41\n",
+ " problem_type=\"single_label_classification\",\n",
+ ")\n",
+ "\n",
+ "print(f\" Loaded Stage 1 backbone with new {NUM_CUAD_LABELS}-class head\")\n",
+ "print(f\" Parameters: {sum(p.numel() for p in stage2_model.parameters()):,}\")\n",
+ "\n",
+ "# Compute class weights from training distribution\n",
+ "train_class_counts = Counter(cuad_train_tok['labels'])\n",
+ "total_samples = sum(train_class_counts.values())\n",
+ "class_weights = []\n",
+ "for i in range(NUM_CUAD_LABELS):\n",
+ " count = train_class_counts.get(i, 1) # avoid div by zero\n",
+ " # Inverse frequency weighting, capped\n",
+ " weight = min(10.0, total_samples / (NUM_CUAD_LABELS * count))\n",
+ " class_weights.append(weight)\n",
+ "\n",
+ "print(f\" Class weight range: [{min(class_weights):.2f}, {max(class_weights):.2f}]\")\n",
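+ "\n",
+ "# Illustrative peek (assumption: the rarest classes hit the 10.0 weight cap):\n",
+ "_rarest = sorted(range(NUM_CUAD_LABELS), key=lambda i: train_class_counts.get(i, 0))[:3]\n",
+ "for _i in _rarest:\n",
+ " print(f\" class {_i} ({CUAD_LABELS[_i]}): count={train_class_counts.get(_i, 0)}, weight={class_weights[_i]:.2f}\")\n",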
+ "\n",
+ "# Create ASL loss\n",
+ "asl_loss = AsymmetricLoss(\n",
+ " gamma_pos=ASL_GAMMA_POS,\n",
+ " gamma_neg=ASL_GAMMA_NEG,\n",
+ " clip=ASL_CLIP,\n",
+ " num_classes=NUM_CUAD_LABELS,\n",
+ " class_weights=class_weights,\n",
+ " mode=\"multi_class\", # single-label per clause\n",
+ ")\n",
+ "# Move to GPU\n",
+ "if torch.cuda.is_available():\n",
+ " asl_loss = asl_loss.cuda()\n",
+ "\n",
+ "stage2_args = TrainingArguments(\n",
+ " output_dir=\"./stage2_cuad\",\n",
+ " num_train_epochs=STAGE2_EPOCHS,\n",
+ " per_device_train_batch_size=STAGE2_BATCH,\n",
+ " per_device_eval_batch_size=16,\n",
+ " gradient_accumulation_steps=STAGE2_GRAD_ACCUM,\n",
+ " learning_rate=STAGE2_LR,\n",
+ " weight_decay=WEIGHT_DECAY,\n",
+ " warmup_ratio=WARMUP_RATIO,\n",
+ " lr_scheduler_type=\"cosine\",\n",
+ " eval_strategy=\"epoch\",\n",
+ " save_strategy=\"epoch\",\n",
+ " save_total_limit=3,\n",
+ " load_best_model_at_end=True,\n",
+ " metric_for_best_model=\"macro_f1\",\n",
+ " greater_is_better=True,\n",
+ " bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,\n",
+ " fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,\n",
+ " logging_strategy=\"steps\",\n",
+ " logging_steps=25,\n",
+ " logging_first_step=True,\n",
+ " disable_tqdm=False,\n",
+ " report_to=\"none\",\n",
+ " push_to_hub=True,\n",
+ " hub_model_id=HUB_MODEL_ID,\n",
+ " dataloader_num_workers=2,\n",
+ " seed=SEED,\n",
+ " gradient_checkpointing=True,\n",
+ ")\n",
+ "\n",
+ "stage2_trainer = ASLTrainer(\n",
+ " model=stage2_model,\n",
+ " args=stage2_args,\n",
+ " asl_loss_fn=asl_loss,\n",
+ " train_dataset=cuad_train_tok,\n",
+ " eval_dataset=cuad_val_tok,\n",
+ " processing_class=tokenizer,\n",
+ " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
+ " compute_metrics=compute_metrics_cuad_detailed,\n",
+ " callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)],\n",
+ ")\n",
+ "\n",
+ "print(\"\\n🚀 Starting Stage 2 training with Asymmetric Loss...\")\n",
+ "stage2_result = stage2_trainer.train()\n",
+ "print(f\"\\n✅ Stage 2 complete! Loss: {stage2_result.training_loss:.4f}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 10: Evaluate Stage 2 on CUAD Test Set"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(\"📊 Stage 2 — CUAD Test Evaluation\")\n",
+ "test_results = stage2_trainer.evaluate(cuad_test_tok)\n",
+ "\n",
+ "print(f\"\\n{'='*60}\")\n",
+ "print(f\" CUAD TEST RESULTS (DeBERTa-v3-large + LEDGAR + ASL)\")\n",
+ "print(f\"{'='*60}\")\n",
+ "print(f\" Accuracy: {test_results['eval_accuracy']:.4f}\")\n",
+ "print(f\" Micro-F1: {test_results['eval_micro_f1']:.4f}\")\n",
+ "print(f\" Macro-F1: {test_results['eval_macro_f1']:.4f}\")\n",
+ "print(f\" Weighted-F1: {test_results['eval_weighted_f1']:.4f}\")\n",
+ "print(f\"{'='*60}\")\n",
+ "\n",
+ "# Per-class F1 report\n",
+ "print(f\"\\n Per-class F1 scores:\")\n",
+ "print(f\" {'Class':<42s} {'F1':>6s}\")\n",
+ "print(f\" {'-'*48}\")\n",
+ "\n",
+ "zero_f1_classes = []\n",
+ "for i, label_name in enumerate(CUAD_LABELS):\n",
+ " safe_name = label_name[:20].replace(\" \", \"_\").replace(\"/\", \"_\")\n",
+ " key = f\"eval_f1_{safe_name}\"\n",
+ " f1_val = test_results.get(key, 0.0)\n",
+ " bar = '█' * int(f1_val * 30)\n",
+ " status = \"\" if f1_val > 0 else \" ← ZERO\"\n",
+ " print(f\" {i:2d} {label_name:<40s} {f1_val:.4f} {bar}{status}\")\n",
+ " if f1_val == 0:\n",
+ " zero_f1_classes.append(label_name)\n",
+ "\n",
+ "print(f\"\\n Classes with zero F1: {len(zero_f1_classes)}\")\n",
+ "if zero_f1_classes:\n",
+ " for c in zero_f1_classes:\n",
+ " print(f\" ⚠️ {c}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 11: Full Classification Report"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Generate full sklearn classification report\n",
+ "from sklearn.metrics import classification_report\n",
+ "\n",
+ "# Get predictions on test set\n",
+ "preds_output = stage2_trainer.predict(cuad_test_tok)\n",
+ "preds = np.argmax(preds_output.predictions, axis=-1)\n",
+ "labels = preds_output.label_ids\n",
+ "\n",
+ "# Only include labels that appear in test set\n",
+ "present_labels = sorted(set(labels) | set(preds))\n",
+ "target_names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f\"Class-{i}\" for i in present_labels]\n",
+ "\n",
+ "report = classification_report(\n",
+ " labels, preds,\n",
+ " labels=present_labels,\n",
+ " target_names=target_names,\n",
+ " zero_division=0,\n",
+ " digits=4,\n",
+ ")\n",
+ "print(\"\\n📊 Full Classification Report:\")\n",
+ "print(report)"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 12: Push Final Model to Hub"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Save model with proper label mapping\n",
+ "stage2_model.config.id2label = {str(i): name for i, name in enumerate(CUAD_LABELS)}\n",
+ "stage2_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}\n",
+ "\n",
+ "# Save locally\n",
+ "FINAL_DIR = \"./clauseguard-deberta-final\"\n",
+ "stage2_trainer.save_model(FINAL_DIR)\n",
+ "tokenizer.save_pretrained(FINAL_DIR)\n",
+ "\n",
+ "# Push to Hub\n",
+ "print(f\"\\n☁️ Pushing model to Hub: {HUB_MODEL_ID}\")\n",
+ "stage2_trainer.push_to_hub(\n",
+ " commit_message=(\n",
+ " f\"ClauseGuard v4: DeBERTa-v3-large 2-stage (LEDGAR→CUAD) with ASL\\n\"\n",
+ " f\"CUAD Test: micro-F1={test_results['eval_micro_f1']:.4f}, \"\n",
+ " f\"macro-F1={test_results['eval_macro_f1']:.4f}\"\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "print(f\"\\n✅ Model pushed to: https://huggingface.co/{HUB_MODEL_ID}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 13: Test the Model on Sample Clauses"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import pipeline as hf_pipeline\n",
+ "\n",
+ "# Load the trained model for inference\n",
+ "classifier = hf_pipeline(\n",
+ " \"text-classification\",\n",
+ " model=stage2_model,\n",
+ " tokenizer=tokenizer,\n",
+ " top_k=5, # return top 5 predictions\n",
+ " device=0 if torch.cuda.is_available() else -1,\n",
+ ")\n",
+ "\n",
+ "test_clauses = [\n",
+ " # High-risk clauses\n",
+ " \"The Company may terminate this Agreement at any time, with or without cause, upon written notice to the other party.\",\n",
+ " \"In no event shall the Company be liable for any indirect, incidental, special, or consequential damages arising out of this Agreement.\",\n",
+ " \"All intellectual property developed during the term of this Agreement shall be owned exclusively by the Company.\",\n",
+ " \"This Agreement shall be governed by and construed in accordance with the laws of the State of Delaware.\",\n",
+ " \"Any disputes arising out of this Agreement shall be resolved through binding arbitration in New York.\",\n",
+ " \"The Employee agrees not to compete with the Company for a period of two (2) years following termination.\",\n",
+ " # Neutral clauses\n",
+ " \"This Agreement shall be effective as of January 1, 2024.\",\n",
+ " \"The initial term of this Agreement shall be three (3) years.\",\n",
+ " \"Either party may assign this Agreement with the prior written consent of the other party.\",\n",
+ "]\n",
+ "\n",
+ "print(\"🧪 Testing model on sample clauses:\\n\")\n",
+ "for clause in test_clauses:\n",
+ " results = classifier(clause, truncation=True, max_length=MAX_LENGTH)\n",
+ " top = results[0] if isinstance(results[0], dict) else results[0][0]\n",
+ " top3 = results[:3] if isinstance(results[0], dict) else results[0][:3]\n",
+ " \n",
+ " print(f\"📄 \\\"{clause[:90]}{'...' if len(clause) > 90 else ''}\\\"\")\n",
+ " for r in top3:\n",
+ " score = r['score']\n",
+ " bar = '█' * int(score * 20)\n",
+ " print(f\" → {r['label']:40s} {score:.4f} {bar}\")\n",
+ " print()"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 14: Generate Updated app.py Integration Code\n",
+ "\n",
+ "Copy-paste this into your ClauseGuard Space's `app.py` to use the new model."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "integration_code = f'''\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "# ClauseGuard v4 — Integration Code\n",
+ "# Replace the model loading section in app.py with this:\n",
+ "# ═══════════════════════════════════════════════════════════════\n",
+ "\n",
+ "# OLD (remove these):\n",
+ "# base = \"nlpaueb/legal-bert-base-uncased\"\n",
+ "# adapter = \"Mokshith31/legalbert-contract-clause-classification\"\n",
+ "# from peft import PeftModel\n",
+ "\n",
+ "# NEW:\n",
+ "CLAUSEGUARD_MODEL = \"{HUB_MODEL_ID}\"\n",
+ "\n",
+ "def _load_cuad_model():\n",
+ " global cuad_tokenizer, cuad_model, _model_status\n",
+ " if not _HAS_TORCH:\n",
+ " _model_status[\"cuad\"] = \"unavailable\"\n",
+ " return\n",
+ " try:\n",
+ " print(f\"[ClauseGuard] Loading classifier: {{CLAUSEGUARD_MODEL}}\")\n",
+ " cuad_tokenizer = AutoTokenizer.from_pretrained(CLAUSEGUARD_MODEL)\n",
+ " cuad_model = AutoModelForSequenceClassification.from_pretrained(CLAUSEGUARD_MODEL)\n",
+ " cuad_model.eval()\n",
+ " _model_status[\"cuad\"] = \"loaded\"\n",
+ " print(f\"[ClauseGuard] Model loaded: {{sum(p.numel() for p in cuad_model.parameters()):,}} params\")\n",
+ " except Exception as e:\n",
+ " print(f\"[ClauseGuard] Model load failed: {{e}}\")\n",
+ " _model_status[\"cuad\"] = f\"failed: {{e}}\"\n",
+ "\n",
+ "# In classify_cuad(), change max_length:\n",
+ "# max_length=256 → max_length=512\n",
+ "#\n",
+ "# Also: since the new model is single-label (softmax),\n",
+ "# change the prediction logic from sigmoid to:\n",
+ "#\n",
+ "# probs = torch.softmax(logits, dim=-1)[0] # instead of sigmoid\n",
+ "# top_indices = torch.argsort(probs, descending=True)[:5]\n",
+ "# for i in top_indices:\n",
+ "# if float(probs[i]) > 0.10: # confidence threshold\n",
+ "# label = CUAD_LABELS[i]\n",
+ "# ...\n",
+ "\n",
+ "# No more PEFT dependency needed!\n",
+ "# No more ignore_mismatched_sizes!\n",
+ "# Just load directly — the model already has the correct head.\n",
+ "'''\n",
+ "\n",
+ "print(integration_code)"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 15: Comparison with Current Model\n",
+ "\n",
+ "| Aspect | Current (Legal-BERT + LoRA) | New (DeBERTa-v3-large + ASL) |\n",
+ "|--------|---------------------------|-----------------------------|\n",
+ "| Base model | 110M params | 435M params |\n",
+ "| Training | LoRA (frozen backbone) | Full fine-tune |\n",
+ "| Intermediate fine-tuning | None | LEDGAR (60K, 100 classes) |\n",
+ "| Max tokens | 256 | 512 |\n",
+ "| Loss function | Cross-entropy | Asymmetric Loss |\n",
+ "| Zero-F1 classes | 10 of 41 | TBD (should be much fewer) |\n",
+ "| Macro-F1 | ~50% | Target: 78-87% |\n",
+ "\n",
+ "---\n",
+ "\n",
+ "## ✅ Done!\n",
+ "\n",
+ "Your trained model is at: **https://huggingface.co/gaurv007/clauseguard-deberta-v3-large**\n",
+ "\n",
+ "### Next Steps:\n",
+ "1. Update ClauseGuard Space to use this model (see integration code above)\n",
+ "2. Remove PEFT dependency from requirements.txt\n",
+ "3. Consider training SetFit classifiers for any remaining zero-F1 classes\n",
+ "4. Add OCR support (Feature #2)\n",
+ "5. Add RAG chatbot (Feature #3)"
+ ],
+ "metadata": {}
+ }
+ ]
+ }