gaurv007 committed on
Commit 05e9815 · verified · 1 Parent(s): 597978a

Add Colab training notebook

Files changed (1)
  1. ml/ClauseGuard_Training.ipynb +367 -0
ml/ClauseGuard_Training.ipynb ADDED
@@ -0,0 +1,367 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "# 🛡️ ClauseGuard: Train Legal-BERT Classifier\n",
23
+ "\n",
24
+ "This notebook fine-tunes **Legal-BERT** on the CLAUDETTE/LexGLUE `unfair_tos` dataset (9,414 clauses, 8 unfair clause categories).\n",
25
+ "\n",
26
+ "**Runtime:** ~30 min on T4 GPU\n",
27
+ "\n",
28
+ "**Before running:**\n",
29
+ "1. Go to `Runtime` → `Change runtime type` → Select **T4 GPU**\n",
30
+ "2. Click `Runtime` → `Run all`\n",
31
+ "3. When prompted, paste your HuggingFace token (needs write access)"
32
+ ],
33
+ "metadata": {}
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "source": [
38
+ "## Step 1: Install Dependencies"
39
+ ],
40
+ "metadata": {}
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "source": [
45
+ "!pip install -q transformers datasets scikit-learn accelerate huggingface_hub"
46
+ ],
47
+ "metadata": {},
48
+ "execution_count": null,
49
+ "outputs": []
50
+ },
51
+ {
52
+ "cell_type": "markdown",
53
+ "source": [
54
+ "## Step 2: Login to HuggingFace Hub"
55
+ ],
56
+ "metadata": {}
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "from huggingface_hub import login\n",
62
+ "login()"
63
+ ],
64
+ "metadata": {},
65
+ "execution_count": null,
66
+ "outputs": []
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "source": [
71
+ "## Step 3: Load Dataset"
72
+ ],
73
+ "metadata": {}
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "source": [
78
+ "from datasets import load_dataset, Sequence, Value\n",
79
+ "\n",
80
+ "dataset = load_dataset(\"coastalcph/lex_glue\", \"unfair_tos\")\n",
81
+ "print(f\"Train: {len(dataset['train'])} | Val: {len(dataset['validation'])} | Test: {len(dataset['test'])}\")\n",
82
+ "print(f\"Label names: {dataset['train'].features['labels'].feature.names}\")\n",
83
+ "print(f\"\\nSample: {dataset['train'][10]}\")"
84
+ ],
85
+ "metadata": {},
86
+ "execution_count": null,
87
+ "outputs": []
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "source": [
92
+ "## Step 4: Load Legal-BERT Model"
93
+ ],
94
+ "metadata": {}
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "source": [
99
+ "from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer\n",
100
+ "\n",
101
+ "MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
102
+ "NUM_LABELS = 8\n",
103
+ "LABEL_NAMES = [\n",
104
+ " \"Limitation of liability\",\n",
105
+ " \"Unilateral termination\",\n",
106
+ " \"Unilateral change\",\n",
107
+ " \"Content removal\",\n",
108
+ " \"Contract by using\",\n",
109
+ " \"Choice of law\",\n",
110
+ " \"Jurisdiction\",\n",
111
+ " \"Arbitration\",\n",
112
+ "]\n",
113
+ "\n",
114
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
115
+ "\n",
116
+ "config = AutoConfig.from_pretrained(\n",
117
+ " MODEL_NAME,\n",
118
+ " num_labels=NUM_LABELS,\n",
119
+ " problem_type=\"multi_label_classification\",\n",
120
+ " id2label={i: n for i, n in enumerate(LABEL_NAMES)},\n",
121
+ " label2id={n: i for i, n in enumerate(LABEL_NAMES)},\n",
122
+ ")\n",
123
+ "\n",
124
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
125
+ " MODEL_NAME, config=config, ignore_mismatched_sizes=True\n",
126
+ ")\n",
127
+ "\n",
128
+ "print(f\"Parameters: {sum(p.numel() for p in model.parameters()):,}\")"
129
+ ],
130
+ "metadata": {},
131
+ "execution_count": null,
132
+ "outputs": []
133
+ },
134
+ {
135
+ "cell_type": "markdown",
136
+ "source": [
137
+ "## Step 5: Preprocess (Multi-hot Float Labels)"
138
+ ],
139
+ "metadata": {}
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "source": [
144
+ "MAX_LENGTH = 512\n",
145
+ "\n",
146
+ "def preprocess(examples):\n",
147
+ " tokenized = tokenizer(\n",
148
+ " examples[\"text\"], truncation=True, max_length=MAX_LENGTH, padding=False\n",
149
+ " )\n",
150
+ " batch_labels = []\n",
151
+ " for lbls in examples[\"labels\"]:\n",
152
+ " vec = [0.0] * NUM_LABELS\n",
153
+ " for label_id in lbls:\n",
154
+ " vec[label_id] = 1.0\n",
155
+ " batch_labels.append(vec)\n",
156
+ " tokenized[\"labels\"] = batch_labels\n",
157
+ " return tokenized\n",
158
+ "\n",
159
+ "print(\"Tokenizing...\")\n",
160
+ "tokenized_ds = dataset.map(preprocess, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
161
+ "\n",
162
+ "# Critical: cast labels to float32 for BCEWithLogitsLoss\n",
163
+ "for split in tokenized_ds:\n",
164
+ " tokenized_ds[split] = tokenized_ds[split].cast_column(\"labels\", Sequence(Value(\"float32\")))\n",
165
+ "\n",
166
+ "tokenized_ds.set_format(\"torch\")\n",
167
+ "\n",
168
+ "# Verify\n",
169
+ "sample = tokenized_ds[\"train\"][0]\n",
170
+ "print(f\"Label dtype: {sample['labels'].dtype} ← must be float32\")\n",
171
+ "print(f\"Label shape: {sample['labels'].shape}\")\n",
172
+ "print(\"✅ Preprocessing done!\")"
173
+ ],
174
+ "metadata": {},
175
+ "execution_count": null,
176
+ "outputs": []
177
+ },
178
+ {
179
+ "cell_type": "markdown",
180
+ "source": [
181
+ "## Step 6: Train!"
182
+ ],
183
+ "metadata": {}
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "source": [
188
+ "import numpy as np\n",
189
+ "import torch\n",
190
+ "from sklearn.metrics import f1_score, precision_score, recall_score\n",
191
+ "from transformers import (\n",
192
+ " DataCollatorWithPadding, Trainer, TrainingArguments, EarlyStoppingCallback\n",
193
+ ")\n",
194
+ "\n",
195
+ "# ── Change this to your HF username ──\n",
196
+ "HUB_MODEL_ID = \"gaurv007/clauseguard-legal-bert\"\n",
197
+ "\n",
198
+ "def compute_metrics(eval_pred):\n",
199
+ " logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
200
+ " probs = 1 / (1 + np.exp(-logits))\n",
201
+ " preds = (probs > 0.5).astype(int)\n",
202
+ " labels = labels.astype(int)\n",
203
+ " micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n",
204
+ " macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
205
+ " micro_p = precision_score(labels, preds, average=\"micro\", zero_division=0)\n",
206
+ " micro_r = recall_score(labels, preds, average=\"micro\", zero_division=0)\n",
207
+ " per_class = f1_score(labels, preds, average=None, zero_division=0)\n",
208
+ " class_metrics = {f\"f1_{LABEL_NAMES[i][:15]}\": float(per_class[i]) for i in range(NUM_LABELS)}\n",
209
+ " return {\"micro_f1\": micro_f1, \"macro_f1\": macro_f1, \"precision\": micro_p, \"recall\": micro_r, **class_metrics}\n",
210
+ "\n",
211
+ "training_args = TrainingArguments(\n",
212
+ " output_dir=\"./clauseguard-model\",\n",
213
+ " num_train_epochs=20,\n",
214
+ " per_device_train_batch_size=16,\n",
215
+ " per_device_eval_batch_size=32,\n",
216
+ " learning_rate=3e-5,\n",
217
+ " weight_decay=0.01,\n",
218
+ " warmup_ratio=0.1,\n",
219
+ " eval_strategy=\"epoch\",\n",
220
+ " save_strategy=\"epoch\",\n",
221
+ " save_total_limit=3,\n",
222
+ " load_best_model_at_end=True,\n",
223
+ " metric_for_best_model=\"macro_f1\",\n",
224
+ " greater_is_better=True,\n",
225
+ " fp16=torch.cuda.is_available(),\n",
226
+ " logging_strategy=\"steps\",\n",
227
+ " logging_steps=25,\n",
228
+ " logging_first_step=True,\n",
229
+ " report_to=\"none\",\n",
230
+ " push_to_hub=True,\n",
231
+ " hub_model_id=HUB_MODEL_ID,\n",
232
+ " seed=42,\n",
233
+ ")\n",
234
+ "\n",
235
+ "trainer = Trainer(\n",
236
+ " model=model,\n",
237
+ " args=training_args,\n",
238
+ " train_dataset=tokenized_ds[\"train\"],\n",
239
+ " eval_dataset=tokenized_ds[\"validation\"],\n",
240
+ " processing_class=tokenizer,\n",
241
+ " data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
242
+ " compute_metrics=compute_metrics,\n",
243
+ " callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
244
+ ")\n",
245
+ "\n",
246
+ "print(f\"🚀 Training on: {training_args.device}\")\n",
247
+ "print(f\" GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
248
+ "print(f\" Epochs: {training_args.num_train_epochs}\")\n",
249
+ "print(f\" Batch size: {training_args.per_device_train_batch_size}\")\n",
250
+ "print(f\" Push to Hub: {HUB_MODEL_ID}\")\n",
251
+ "print()\n",
252
+ "\n",
253
+ "train_result = trainer.train()\n",
254
+ "print(f\"\\n✅ Training complete! Loss: {train_result.training_loss:.4f}\")"
255
+ ],
256
+ "metadata": {},
257
+ "execution_count": null,
258
+ "outputs": []
259
+ },
260
+ {
261
+ "cell_type": "markdown",
262
+ "source": [
263
+ "## Step 7: Evaluate on Test Set"
264
+ ],
265
+ "metadata": {}
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "source": [
270
+ "print(\"📊 Evaluating on test set...\")\n",
271
+ "test_results = trainer.evaluate(tokenized_ds[\"test\"])\n",
272
+ "\n",
273
+ "print(f\"\\n{'='*50}\")\n",
274
+ "print(f\" TEST RESULTS\")\n",
275
+ "print(f\"{'='*50}\")\n",
276
+ "print(f\" Micro-F1: {test_results['eval_micro_f1']:.4f}\")\n",
277
+ "print(f\" Macro-F1: {test_results['eval_macro_f1']:.4f}\")\n",
278
+ "print(f\" Precision: {test_results['eval_precision']:.4f}\")\n",
279
+ "print(f\" Recall: {test_results['eval_recall']:.4f}\")\n",
280
+ "print(f\"{'='*50}\")\n",
281
+ "print(f\"\\n Per-class F1:\")\n",
282
+ "for name in LABEL_NAMES:\n",
283
+ " key = f\"eval_f1_{name[:15]}\"\n",
284
+ " print(f\" {name:30s} {test_results.get(key, 0):.4f}\")"
285
+ ],
286
+ "metadata": {},
287
+ "execution_count": null,
288
+ "outputs": []
289
+ },
290
+ {
291
+ "cell_type": "markdown",
292
+ "source": [
293
+ "## Step 8: Push to HuggingFace Hub"
294
+ ],
295
+ "metadata": {}
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "source": [
300
+ "print(f\"☁️ Pushing model to Hub: {HUB_MODEL_ID}\")\n",
301
+ "trainer.push_to_hub(commit_message=\"ClauseGuard Legal-BERT fine-tuned on CLAUDETTE unfair_tos\")\n",
302
+ "print(f\"\\n✅ Model pushed! View at: https://huggingface.co/{HUB_MODEL_ID}\")"
303
+ ],
304
+ "metadata": {},
305
+ "execution_count": null,
306
+ "outputs": []
307
+ },
308
+ {
309
+ "cell_type": "markdown",
310
+ "source": [
311
+ "## Step 9: Test the Model"
312
+ ],
313
+ "metadata": {}
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "source": [
318
+ "from transformers import pipeline\n",
319
+ "\n",
320
+ "classifier = pipeline(\n",
321
+ " \"text-classification\",\n",
322
+ " model=trainer.model,\n",
323
+ " tokenizer=tokenizer,\n",
324
+ " top_k=None,\n",
325
+ " device=0 if torch.cuda.is_available() else -1,\n",
326
+ ")\n",
327
+ "\n",
328
+ "test_clauses = [\n",
329
+ " \"The company may terminate your account at any time, with or without cause, with or without notice.\",\n",
330
+ " \"By using this service, you agree to be bound by these terms.\",\n",
331
+ " \"In no event shall the company be liable for any indirect, incidental, or consequential damages.\",\n",
332
+ " \"These terms shall be governed by the laws of the State of California.\",\n",
333
+ " \"Any disputes shall be resolved through binding arbitration.\",\n",
334
+ " \"We reserve the right to modify these terms at any time without prior notice.\",\n",
335
+ " \"The refund will be processed within 30 business days.\",\n",
336
+ "]\n",
337
+ "\n",
338
+ "print(\"🧪 Testing model on sample clauses:\\n\")\n",
339
+ "for clause in test_clauses:\n",
340
+ " results = classifier(clause, truncation=True, max_length=512)\n",
341
+ " flagged = [r for r in results[0] if r[\"score\"] > 0.5]\n",
342
+ " if flagged:\n",
343
+ " flags = \", \".join([f\"{r['label']} ({r['score']:.2f})\" for r in flagged])\n",
344
+ " print(f\"🔴 \\\"{clause[:80]}...\\\"\")\n",
345
+ " print(f\" → {flags}\\n\")\n",
346
+ " else:\n",
347
+ " print(f\"✅ \\\"{clause[:80]}...\\\"\")\n",
348
+ " print(f\" → Fair clause\\n\")"
349
+ ],
350
+ "metadata": {},
351
+ "execution_count": null,
352
+ "outputs": []
353
+ },
354
+ {
355
+ "cell_type": "markdown",
356
+ "source": [
357
+ "## ✅ Done!\n",
358
+ "\n",
359
+ "Your trained model is now at:\n",
360
+ "**https://huggingface.co/gaurv007/clauseguard-legal-bert**\n",
361
+ "\n",
362
+ "The live demo at **https://huggingface.co/spaces/gaurv007/ClauseGuard** can now be updated to use this model."
363
+ ],
364
+ "metadata": {}
365
+ }
366
+ ]
367
+ }
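
Note on the multi-label encoding used in Steps 5 and 6: the notebook turns each clause's list of class indices into an 8-slot float vector, then scores predictions by applying a sigmoid to each logit and thresholding at 0.5. The sketch below reproduces that logic in plain Python, without `torch` or `numpy`, so it can be checked by hand. The helper names `multi_hot`, `sigmoid`, and `predict` are illustrative and do not appear in the notebook. The sigmoid here is written in a numerically stable form, which is an assumption on top of the notebook's `1 / (1 + np.exp(-logits))`.

```python
import math

NUM_LABELS = 8  # same value as NUM_LABELS in the notebook


def multi_hot(label_ids, num_labels=NUM_LABELS):
    """Convert a list of class indices into a float multi-hot vector."""
    vec = [0.0] * num_labels
    for label_id in label_ids:
        vec[label_id] = 1.0
    return vec


def sigmoid(x):
    """Logistic function, split by sign so math.exp never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def predict(logits, threshold=0.5):
    """Flag each label independently when its probability exceeds the threshold."""
    return [1 if sigmoid(z) > threshold else 0 for z in logits]


if __name__ == "__main__":
    # A clause tagged with classes 1 and 6 (hypothetical example input).
    print(multi_hot([1, 6]))
    # A clause with no unfair category maps to the all-zero vector.
    print(multi_hot([]))
    # A logit of exactly 0.0 gives probability 0.5, which is not > 0.5.
    print(predict([2.0, -1.5, 0.0]))
```

Because each label is thresholded on its own, a clause can receive zero, one, or several flags, which is why the notebook casts labels to `float32` and sets `problem_type="multi_label_classification"` rather than using a single softmax class.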