{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 🛡️ ClauseGuard — Train Legal-BERT Classifier\n",
    "\n",
    "This notebook fine-tunes **Legal-BERT** on the CLAUDETTE/LexGLUE `unfair_tos` dataset (9,414 clauses, 8 unfair clause categories).\n",
    "\n",
    "**Runtime:** ~30 min on T4 GPU\n",
    "\n",
    "**Before running:**\n",
    "1. Go to `Runtime` → `Change runtime type` → Select **T4 GPU**\n",
    "2. In Step 6, set `HUB_MODEL_ID` to a repo under your own HuggingFace username\n",
    "3. Click `Runtime` → `Run all`\n",
    "4. When prompted, paste your HuggingFace token (needs write access)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 1: Install Dependencies"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "# transformers>=4.46 is needed for `eval_strategy` and `Trainer(processing_class=...)` used below.\n",
    "%pip install -q \"transformers>=4.46\" datasets scikit-learn accelerate huggingface_hub"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 2: Login to HuggingFace Hub"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from huggingface_hub import login\n",
    "login()"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 3: Load Dataset"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from datasets import load_dataset, Sequence, Value\n",
    "\n",
    "dataset = load_dataset(\"coastalcph/lex_glue\", \"unfair_tos\")\n",
    "print(f\"Train: {len(dataset['train'])} | Val: {len(dataset['validation'])} | Test: {len(dataset['test'])}\")\n",
    "print(f\"Label names: {dataset['train'].features['labels'].feature.names}\")\n",
    "print(f\"\\nSample: {dataset['train'][10]}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 4: Load Legal-BERT Model"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, set_seed\n",
    "\n",
    "MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
    "SEED = 42\n",
    "NUM_LABELS = 8\n",
    "LABEL_NAMES = [\n",
    "    \"Limitation of liability\",\n",
    "    \"Unilateral termination\",\n",
    "    \"Unilateral change\",\n",
    "    \"Content removal\",\n",
    "    \"Contract by using\",\n",
    "    \"Choice of law\",\n",
    "    \"Jurisdiction\",\n",
    "    \"Arbitration\",\n",
    "]\n",
    "\n",
    "# Multi-hot index i (Step 5) must mean the same class as dataset label id i;\n",
    "# fail fast if these hardcoded names ever drift from the dataset's ClassLabel order.\n",
    "assert LABEL_NAMES == dataset[\"train\"].features[\"labels\"].feature.names, \"LABEL_NAMES out of sync with dataset\"\n",
    "assert len(LABEL_NAMES) == NUM_LABELS\n",
    "\n",
    "# Seed BEFORE from_pretrained: the classification head is randomly initialised here,\n",
    "# while TrainingArguments(seed=...) is only applied later, when the Trainer is constructed.\n",
    "set_seed(SEED)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "config = AutoConfig.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    num_labels=NUM_LABELS,\n",
    "    problem_type=\"multi_label_classification\",\n",
    "    id2label={i: n for i, n in enumerate(LABEL_NAMES)},\n",
    "    label2id={n: i for i, n in enumerate(LABEL_NAMES)},\n",
    ")\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    MODEL_NAME, config=config, ignore_mismatched_sizes=True\n",
    ")\n",
    "\n",
    "print(f\"Parameters: {sum(p.numel() for p in model.parameters()):,}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 5: Preprocess — Multi-hot Float Labels"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "MAX_LENGTH = 512\n",
    "\n",
    "def preprocess(examples):\n",
    "    \"\"\"Tokenize a batch of clauses and turn each list of label ids into a multi-hot float vector.\"\"\"\n",
    "    tokenized = tokenizer(\n",
    "        examples[\"text\"], truncation=True, max_length=MAX_LENGTH, padding=False\n",
    "    )\n",
    "    batch_labels = []\n",
    "    for label_ids in examples[\"labels\"]:\n",
    "        vec = [0.0] * NUM_LABELS\n",
    "        for label_id in label_ids:\n",
    "            vec[label_id] = 1.0\n",
    "        batch_labels.append(vec)\n",
    "    tokenized[\"labels\"] = batch_labels\n",
    "    return tokenized\n",
    "\n",
    "print(\"Tokenizing...\")\n",
    "tokenized_ds = dataset.map(preprocess, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
    "\n",
    "# Critical: cast labels to float32 for BCEWithLogitsLoss (applied to every split at once)\n",
    "tokenized_ds = tokenized_ds.cast_column(\"labels\", Sequence(Value(\"float32\")))\n",
    "\n",
    "tokenized_ds.set_format(\"torch\")\n",
    "\n",
    "# Verify\n",
    "sample = tokenized_ds[\"train\"][0]\n",
    "assert sample[\"labels\"].dtype == torch.float32, sample[\"labels\"].dtype\n",
    "assert tuple(sample[\"labels\"].shape) == (NUM_LABELS,), sample[\"labels\"].shape\n",
    "print(f\"Label dtype: {sample['labels'].dtype} ← must be float32\")\n",
    "print(f\"Label shape: {sample['labels'].shape}\")\n",
    "print(\"✅ Preprocessing done!\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 6: Train!"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from scipy.special import expit\n",
    "from sklearn.metrics import f1_score, precision_score, recall_score\n",
    "from transformers import (\n",
    "    DataCollatorWithPadding, Trainer, TrainingArguments, EarlyStoppingCallback\n",
    ")\n",
    "\n",
    "# ── Change this to your HF username ──\n",
    "HUB_MODEL_ID = \"gaurv007/clauseguard-legal-bert\"\n",
    "\n",
    "# Decision threshold on each class's sigmoid probability (used here and in Step 9).\n",
    "THRESHOLD = 0.5\n",
    "\n",
    "def class_f1_key(name):\n",
    "    \"\"\"Metric name for one class's F1 score; shared by compute_metrics and the Step 7 report.\"\"\"\n",
    "    return f\"f1_{name[:15]}\"\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Threshold sigmoid probabilities and return micro/macro/per-class multi-label scores.\"\"\"\n",
    "    logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
    "    # Upcast first (fp16 evaluation may yield float16 logits, where np.exp overflows)\n",
    "    # and use scipy's numerically stable sigmoid instead of 1 / (1 + exp(-x)).\n",
    "    probs = expit(np.asarray(logits, dtype=np.float32))\n",
    "    preds = (probs > THRESHOLD).astype(int)\n",
    "    labels = labels.astype(int)\n",
    "    micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n",
    "    macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n",
    "    micro_p = precision_score(labels, preds, average=\"micro\", zero_division=0)\n",
    "    micro_r = recall_score(labels, preds, average=\"micro\", zero_division=0)\n",
    "    per_class = f1_score(labels, preds, average=None, zero_division=0)\n",
    "    class_metrics = {class_f1_key(name): float(score) for name, score in zip(LABEL_NAMES, per_class)}\n",
    "    return {\"micro_f1\": micro_f1, \"macro_f1\": macro_f1, \"precision\": micro_p, \"recall\": micro_r, **class_metrics}\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./clauseguard-model\",\n",
    "    num_train_epochs=20,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=32,\n",
    "    learning_rate=3e-5,\n",
    "    weight_decay=0.01,\n",
    "    warmup_ratio=0.1,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    save_total_limit=3,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"macro_f1\",\n",
    "    greater_is_better=True,\n",
    "    fp16=torch.cuda.is_available(),\n",
    "    logging_strategy=\"steps\",\n",
    "    logging_steps=25,\n",
    "    logging_first_step=True,\n",
    "    report_to=\"none\",\n",
    "    push_to_hub=True,\n",
    "    hub_model_id=HUB_MODEL_ID,\n",
    "    seed=SEED,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_ds[\"train\"],\n",
    "    eval_dataset=tokenized_ds[\"validation\"],\n",
    "    processing_class=tokenizer,\n",
    "    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
    "    compute_metrics=compute_metrics,\n",
    "    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
    ")\n",
    "\n",
    "print(f\"🚀 Training on: {training_args.device}\")\n",
    "print(f\"   GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n",
    "print(f\"   Epochs: {training_args.num_train_epochs}\")\n",
    "print(f\"   Batch size: {training_args.per_device_train_batch_size}\")\n",
    "print(f\"   Push to Hub: {HUB_MODEL_ID}\")\n",
    "print()\n",
    "\n",
    "train_result = trainer.train()\n",
    "print(f\"\\n✅ Training complete! Loss: {train_result.training_loss:.4f}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 7: Evaluate on Test Set"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "print(\"📊 Evaluating on test set...\")\n",
    "# metric_key_prefix=\"test\" keeps these results distinct from the validation (\"eval_\") metrics.\n",
    "test_results = trainer.evaluate(tokenized_ds[\"test\"], metric_key_prefix=\"test\")\n",
    "\n",
    "print(f\"\\n{'='*50}\")\n",
    "print(\"  TEST RESULTS\")\n",
    "print(f\"{'='*50}\")\n",
    "print(f\"  Micro-F1:  {test_results['test_micro_f1']:.4f}\")\n",
    "print(f\"  Macro-F1:  {test_results['test_macro_f1']:.4f}\")\n",
    "print(f\"  Precision: {test_results['test_precision']:.4f}\")\n",
    "print(f\"  Recall:    {test_results['test_recall']:.4f}\")\n",
    "print(f\"{'='*50}\")\n",
    "print(\"\\n  Per-class F1:\")\n",
    "for name in LABEL_NAMES:\n",
    "    # Index directly so a key mismatch fails loudly instead of printing a fake 0.0000.\n",
    "    print(f\"    {name:30s} {test_results['test_' + class_f1_key(name)]:.4f}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 8: Push to HuggingFace Hub"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "print(f\"☁️ Pushing model to Hub: {HUB_MODEL_ID}\")\n",
    "trainer.push_to_hub(commit_message=\"ClauseGuard Legal-BERT fine-tuned on CLAUDETTE unfair_tos\")\n",
    "print(f\"\\n✅ Model pushed! View at: https://huggingface.co/{HUB_MODEL_ID}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 9: Test the Model"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(\n",
    "    \"text-classification\",\n",
    "    model=trainer.model,\n",
    "    tokenizer=tokenizer,\n",
    "    top_k=None,  # return a score for every label\n",
    "    device=0 if torch.cuda.is_available() else -1,\n",
    ")\n",
    "\n",
    "test_clauses = [\n",
    "    \"The company may terminate your account at any time, with or without cause, with or without notice.\",\n",
    "    \"By using this service, you agree to be bound by these terms.\",\n",
    "    \"In no event shall the company be liable for any indirect, incidental, or consequential damages.\",\n",
    "    \"These terms shall be governed by the laws of the State of California.\",\n",
    "    \"Any disputes shall be resolved through binding arbitration.\",\n",
    "    \"We reserve the right to modify these terms at any time without prior notice.\",\n",
    "    \"The refund will be processed within 30 business days.\",\n",
    "]\n",
    "\n",
    "def preview(text, width=80):\n",
    "    \"\"\"Shorten text for display, adding an ellipsis only when it was actually cut.\"\"\"\n",
    "    return text if len(text) <= width else text[:width] + \"...\"\n",
    "\n",
    "print(\"🧪 Testing model on sample clauses:\\n\")\n",
    "# Passing the whole list returns one list of {label, score} dicts per clause, avoiding\n",
    "# the pipeline's special-cased return shape for a single string input.\n",
    "all_scores = classifier(test_clauses, truncation=True, max_length=MAX_LENGTH)\n",
    "for clause, scores in zip(test_clauses, all_scores):\n",
    "    flagged = [r for r in scores if r[\"score\"] > THRESHOLD]\n",
    "    if flagged:\n",
    "        flags = \", \".join(f\"{r['label']} ({r['score']:.2f})\" for r in flagged)\n",
    "        print(f\"🔴 \\\"{preview(clause)}\\\"\")\n",
    "        print(f\"   → {flags}\\n\")\n",
    "    else:\n",
    "        print(f\"✅ \\\"{preview(clause)}\\\"\")\n",
    "        print(\"   → Fair clause\\n\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## ✅ Done!\n",
    "\n",
    "Your trained model is now at:\n",
    "**https://huggingface.co/gaurv007/clauseguard-legal-bert**\n",
    "(or `https://huggingface.co/<HUB_MODEL_ID>` if you changed `HUB_MODEL_ID` in Step 6).\n",
    "\n",
    "The live demo at **https://huggingface.co/spaces/gaurv007/ClauseGuard** can now be updated to use this model."
   ],
   "metadata": {}
  }
 ]
}