{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "# 🛑️ ClauseGuard v4: DeBERTa-v3-large 2-Stage Training\n", "\n", "**Goal:** Train a production-grade contract clause classifier to replace the current Legal-BERT-base model (~50% macro-F1 today; target 80-87% macro-F1).\n", "\n", "## Architecture\n", "| Setting | Value | Source |\n", "|---------|-------|--------|\n", "| Base model | `microsoft/deberta-v3-large` (435M params) | LexGLUE: outperforms Legal-BERT by 7-10pp |\n", "| Max length | 512 tokens | MAUD paper: covers 72.4% of clauses without truncation |\n", "| Loss function | Asymmetric Loss, single-label (focal-style) variant: γ-=4 plus class weights | ASL paper (2009.14119): +3-8pp on rare classes |\n", "| Training | Full fine-tuning (no LoRA) | Full FT wins for encoder classification |\n", "\n", "## 2-Stage Training Pipeline\n", "1. **Stage 1: LEDGAR** (60K legal provisions, 100 classes): teaches \"what types of contract clauses exist\"\n", "2. **Stage 2: CUAD** (41 CUAD classes): the target task, trained with Asymmetric Loss to counter class imbalance\n", "\n", "**Runtime:** ~8-12 hours on a T4 GPU (~4-6 hours on an A100)\n", "\n", "**Before running:**\n", "1. `Runtime` → `Change runtime type` → **T4 GPU**\n", "2. `Runtime` → `Run all`\n", "3. 
Paste your HuggingFace token when prompted" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Step 1: Install Dependencies" ], "metadata": {} }, { "cell_type": "code", "source": [ "!pip install -q transformers datasets scikit-learn accelerate huggingface_hub torch\n", "!pip install -q trackio  # optional: experiment tracking" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 2: Log in to the Hugging Face Hub" ], "metadata": {} }, { "cell_type": "code", "source": [ "from huggingface_hub import login\n", "login()" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 3: Configuration" ], "metadata": {} }, { "cell_type": "code", "source": [ "import os\n", "import torch\n", "import numpy as np\n", "\n", "# ═══════════════════════════════════════════════════════════════\n", "# CONFIGURATION: edit these values\n", "# ═══════════════════════════════════════════════════════════════\n", "\n", "BASE_MODEL = \"microsoft/deberta-v3-large\"  # 435M params, MIT license\n", "MAX_LENGTH = 512  # covers 72.4% of clauses\n", "HUB_MODEL_ID = \"gaurv007/clauseguard-deberta-v3-large\"  # ← your model repo\n", "\n", "# Stage 1: LEDGAR config\n", "STAGE1_EPOCHS = 5  # LEDGAR is large, converges fast\n", "STAGE1_LR = 2e-5\n", "STAGE1_BATCH = 2  # reduced so DeBERTa-v3-large fits in fp32 on a 16 GB T4\n", "STAGE1_GRAD_ACCUM = 16  # effective batch = 32 (2 * 16)\n", "\n", "# Stage 2: CUAD config\n", "STAGE2_EPOCHS = 20\n", "STAGE2_LR = 1e-5  # lower LR when fine-tuning the Stage 1 model\n", "STAGE2_BATCH = 2  # reduced so DeBERTa-v3-large fits in fp32 on a 16 GB T4\n", "STAGE2_GRAD_ACCUM = 16  # effective batch = 32 (2 * 16)\n", "EARLY_STOPPING_PATIENCE = 3\n", "\n", "# ASL hyperparameters (from arXiv 2009.14119)\n", "ASL_GAMMA_POS = 0\n", "ASL_GAMMA_NEG = 4\n", "ASL_CLIP = 0.05\n", "\n", "# Weight decay (DeBERTa default)\n", "WEIGHT_DECAY = 0.06\n", "WARMUP_RATIO = 0.1\n", "\n", "SEED = 42\n", 
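"\n", "# Illustrative sketch (not part of the original recipe): estimate the Stage 1 optimizer schedule.\n", "# Assumes ~60,000 LEDGAR training examples, the size quoted in the intro; the real count is printed in Step 4.\n", "# steps/epoch = ceil(examples / (per-device batch * grad accumulation)); warmup = ceil(WARMUP_RATIO * total steps)\n", "import math\n", "ASSUMED_LEDGAR_TRAIN = 60_000  # hypothetical constant, used only for this estimate\n", "stage1_steps_per_epoch = math.ceil(ASSUMED_LEDGAR_TRAIN / (STAGE1_BATCH * STAGE1_GRAD_ACCUM))\n", "stage1_total_steps = stage1_steps_per_epoch * STAGE1_EPOCHS\n", "print(f\"Stage 1 estimate: {stage1_steps_per_epoch} steps/epoch, {stage1_total_steps} total, ~{math.ceil(WARMUP_RATIO * stage1_total_steps)} warmup\")\n", 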
"\n", "# ═══════════════════════════════════════════════════════════════\n", "\n", "# CUAD label names. This order must match class_id 0-40 in the CUAD dataset loaded in Step 4;\n", "# verify it against the dataset's own label metadata before training. Note that the original\n", "# CUAD category list includes \"Competitive Restriction Exception\", which does not appear below.\n", "CUAD_LABELS = [\n", "    \"Document Name\",  # 0\n", "    \"Parties\",  # 1\n", "    \"Agreement Date\",  # 2\n", "    \"Effective Date\",  # 3\n", "    \"Expiration Date\",  # 4\n", "    \"Renewal Term\",  # 5\n", "    \"Notice Period to Terminate Renewal\",  # 6\n", "    \"Governing Law\",  # 7\n", "    \"Most Favored Nation\",  # 8\n", "    \"Non-Compete\",  # 9\n", "    \"Exclusivity\",  # 10\n", "    \"No-Solicit of Customers\",  # 11\n", "    \"No-Solicit of Employees\",  # 12\n", "    \"Non-Disparagement\",  # 13\n", "    \"Termination for Convenience\",  # 14\n", "    \"ROFR/ROFO/ROFN\",  # 15\n", "    \"Change of Control\",  # 16\n", "    \"Anti-Assignment\",  # 17\n", "    \"Revenue/Profit Sharing\",  # 18\n", "    \"Price Restriction\",  # 19\n", "    \"Minimum Commitment\",  # 20\n", "    \"Volume Restriction\",  # 21\n", "    \"IP Ownership Assignment\",  # 22\n", "    \"Joint IP Ownership\",  # 23\n", "    \"License Grant\",  # 24\n", "    \"Non-Transferable License\",  # 25\n", "    \"Affiliate License-Licensor\",  # 26\n", "    \"Affiliate License-Licensee\",  # 27\n", "    \"Unlimited/All-You-Can-Eat License\",  # 28\n", "    \"Irrevocable or Perpetual License\",  # 29\n", "    \"Source Code Escrow\",  # 30\n", "    \"Post-Termination Services\",  # 31\n", "    \"Audit Rights\",  # 32\n", "    \"Uncapped Liability\",  # 33\n", "    \"Cap on Liability\",  # 34\n", "    \"Liquidated Damages\",  # 35\n", "    \"Warranty Duration\",  # 36\n", "    \"Insurance\",  # 37\n", "    \"Covenant Not to Sue\",  # 38\n", "    \"Third Party Beneficiary\",  # 39\n", "    \"Other\",  # 40\n", "]\n", "\n", "NUM_CUAD_LABELS = len(CUAD_LABELS)  # 41\n", "\n", "print(\"🛑️ ClauseGuard v4 Training Configuration\")\n", "print(f\"  Base model: {BASE_MODEL}\")\n", "print(f\"  Max length: {MAX_LENGTH}\")\n", "print(f\"  Hub model: {HUB_MODEL_ID}\")\n", "print(f\"  GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n", "print(f\"  VRAM: 
{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\" if torch.cuda.is_available() else \"\")\n", "print(f\"  CUAD classes: {NUM_CUAD_LABELS}\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 4: Load Datasets" ], "metadata": {} }, { "cell_type": "code", "source": [ "from datasets import load_dataset, Dataset\n", "import pandas as pd\n", "from collections import Counter\n", "\n", "# ═══════════════════════════════════════════════════════════════\n", "# Stage 1: LEDGAR (100 classes, single-label)\n", "# ═══════════════════════════════════════════════════════════════\n", "print(\"📚 Loading LEDGAR dataset...\")\n", "ledgar = load_dataset(\"coastalcph/lex_glue\", \"ledgar\")\n", "print(f\"  Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,} | Test: {len(ledgar['test']):,}\")\n", "num_ledgar_labels = ledgar['train'].features['label'].num_classes\n", "print(f\"  Classes: {num_ledgar_labels}\")\n", "\n", "# ═══════════════════════════════════════════════════════════════\n", "# Stage 2: CUAD (41 classes, reformulated for clause classification)\n", "# ═══════════════════════════════════════════════════════════════\n", "print(\"\\n📚 Loading CUAD classification dataset...\")\n", "cuad_raw = load_dataset(\"dvgodoy/CUAD_v1_Contract_Understanding_clause_classification\", split=\"train\")\n", "print(f\"  Total rows: {len(cuad_raw):,}\")\n", "\n", "# Analyze class distribution\n", "class_counts = Counter(cuad_raw['class_id'])\n", "print(f\"  Unique classes: {len(class_counts)}\")\n", "print(\"\\n  Class distribution:\")\n", "for cid in sorted(class_counts.keys()):\n", "    label_name = CUAD_LABELS[cid] if cid < len(CUAD_LABELS) else f\"Unknown-{cid}\"\n", "    count = class_counts[cid]\n", "    bar = '█' * min(50, count // 10)\n", "    print(f\"    {cid:2d} {label_name:40s} {count:5d} {bar}\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 5: 
Prepare CUAD Train/Val/Test Splits" ], "metadata": {} }, { "cell_type": "code", "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "# This CUAD release ships only a train split, so build val/test by splitting on file_name.\n", "# Every clause from a given contract lands in exactly one split (no leakage across contracts).\n", "cuad_df = cuad_raw.to_pandas()\n", "\n", "# Get unique file names\n", "unique_files = cuad_df['file_name'].unique()\n", "print(f\"Unique contracts: {len(unique_files)}\")\n", "\n", "# Split files 80/10/10 (80% train, then the remaining 20% halved into val and test)\n", "train_files, test_files = train_test_split(unique_files, test_size=0.2, random_state=SEED)\n", "val_files, test_files = train_test_split(test_files, test_size=0.5, random_state=SEED)\n", "\n", "cuad_train_df = cuad_df[cuad_df['file_name'].isin(train_files)]\n", "cuad_val_df = cuad_df[cuad_df['file_name'].isin(val_files)]\n", "cuad_test_df = cuad_df[cuad_df['file_name'].isin(test_files)]\n", "\n", "print(f\"CUAD splits | Train: {len(cuad_train_df)} | Val: {len(cuad_val_df)} | Test: {len(cuad_test_df)}\")\n", "print(f\"Train contracts: {len(train_files)} | Val contracts: {len(val_files)} | Test contracts: {len(test_files)}\")\n", "\n", "# Convert to HF Dataset\n", "cuad_train = Dataset.from_pandas(cuad_train_df.reset_index(drop=True))\n", "cuad_val = Dataset.from_pandas(cuad_val_df.reset_index(drop=True))\n", "cuad_test = Dataset.from_pandas(cuad_test_df.reset_index(drop=True))\n", "\n", "# Verify class distribution in each split\n", "for name, ds in [(\"Train\", cuad_train), (\"Val\", cuad_val), (\"Test\", cuad_test)]:\n", "    counts = Counter(ds['class_id'])\n", "    empty_classes = [i for i in range(NUM_CUAD_LABELS) if counts.get(i, 0) == 0]\n", "    print(f\"  {name}: {len(ds)} rows, {len(counts)} classes present, {len(empty_classes)} classes missing: {empty_classes[:5]}...\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 6: Tokenizer & Preprocessing" ], "metadata": {} }, { "cell_type": "code", "source": [ "from transformers 
import AutoTokenizer\n", "\n", "print(f\"Loading tokenizer: {BASE_MODEL}\")\n", "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n", "\n", "# ── LEDGAR preprocessing (single-label) ──\n", "def preprocess_ledgar(examples):\n", "    tokenized = tokenizer(\n", "        examples[\"text\"],\n", "        truncation=True,\n", "        max_length=MAX_LENGTH,\n", "        padding=False,\n", "    )\n", "    tokenized[\"labels\"] = examples[\"label\"]  # int label for CrossEntropy\n", "    return tokenized\n", "\n", "# ── CUAD preprocessing (single-label per clause, 41 classes) ──\n", "def preprocess_cuad(examples):\n", "    tokenized = tokenizer(\n", "        examples[\"clause\"],\n", "        truncation=True,\n", "        max_length=MAX_LENGTH,\n", "        padding=False,\n", "    )\n", "    tokenized[\"labels\"] = examples[\"class_id\"]  # int label for CrossEntropy + ASL\n", "    return tokenized\n", "\n", "print(\"Tokenizing LEDGAR...\")\n", "ledgar_tokenized = ledgar.map(\n", "    preprocess_ledgar, batched=True,\n", "    remove_columns=ledgar[\"train\"].column_names,\n", "    desc=\"Tokenizing LEDGAR\"\n", ")\n", "\n", "print(\"Tokenizing CUAD...\")\n", "cuad_train_tok = cuad_train.map(\n", "    preprocess_cuad, batched=True,\n", "    remove_columns=cuad_train.column_names,\n", "    desc=\"Tokenizing CUAD train\"\n", ")\n", "cuad_val_tok = cuad_val.map(\n", "    preprocess_cuad, batched=True,\n", "    remove_columns=cuad_val.column_names,\n", "    desc=\"Tokenizing CUAD val\"\n", ")\n", "cuad_test_tok = cuad_test.map(\n", "    preprocess_cuad, batched=True,\n", "    remove_columns=cuad_test.column_names,\n", "    desc=\"Tokenizing CUAD test\"\n", ")\n", "\n", "# Check token lengths (sequences that reach MAX_LENGTH were most likely cut off by truncation)\n", "train_lengths = [len(x) for x in cuad_train_tok['input_ids']]\n", "print(f\"\\n📊 CUAD token length stats:\")\n", "print(f\"  Mean: {np.mean(train_lengths):.0f} | Median: {np.median(train_lengths):.0f}\")\n", "print(f\"  95th pct: {np.percentile(train_lengths, 95):.0f} | Max: {max(train_lengths)}\")\n", "print(f\"  At the {MAX_LENGTH}-token cap (likely truncated): {sum(1 for l in train_lengths if l >= MAX_LENGTH)} ({sum(1 for l 
in train_lengths if l >= MAX_LENGTH)/len(train_lengths)*100:.1f}%)\")\n", "print(\"✅ Tokenization complete!\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 7: Asymmetric Loss Function\n", "\n", "From [Asymmetric Loss For Multi-Label Classification](https://arxiv.org/abs/2009.14119) (ICCV 2021).\n", "\n", "Key idea: down-weight easy negatives more aggressively than positives. This matters in multi-label settings, where most labels are negative for any given example.\n", "\n", "This notebook treats CUAD as single-label (exactly one class per clause), so the class below runs in `mode=\"multi_class\"`. In that mode it reduces to a class-weighted, focal-style cross-entropy, `(1 - p_t) ** gamma_neg * CE`, where `p_t` is the softmax probability of the true class. `gamma_pos` and `clip` only take effect in `mode=\"multi_label\"`." ], "metadata": {} }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "class AsymmetricLoss(nn.Module):\n", "    \"\"\"\n", "    Asymmetric Loss from arXiv:2009.14119.\n", "\n", "    mode=\"multi_class\" (single-label, used for CUAD here):\n", "        Focal-style cross-entropy, loss = (1 - p_t)**gamma_neg * CE, with optional\n", "        class weights. gamma_pos and clip are not used in this mode.\n", "\n", "    mode=\"multi_label\" (multi-hot targets), with margin m = clip:\n", "        L+ = (1-p)^γ+ * log(p)\n", "        L- = (pm)^γ- * log(1-pm),  pm = max(p - m, 0)\n", "    \"\"\"\n", "    def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,\n", "                 num_classes=None, class_weights=None, mode=\"multi_class\"):\n", "        # num_classes is accepted for API symmetry but is not used\n", "        super().__init__()\n", "        self.gamma_pos = gamma_pos\n", "        self.gamma_neg = gamma_neg\n", "        self.clip = clip\n", "        self.eps = eps\n", "        self.mode = mode\n", "\n", "        # Optional class weights for severe imbalance\n", "        if class_weights is not None:\n", "            self.register_buffer('class_weights', torch.tensor(class_weights, dtype=torch.float32))\n", "        else:\n", "            self.class_weights = None\n", "\n", "    def forward(self, logits, targets):\n", "        if self.mode == \"multi_label\":\n", "            return self._multi_label_loss(logits, targets)\n", "        else:\n", "            return self._multi_class_loss(logits, targets)\n", "\n", "    def _multi_class_loss(self, logits, targets):\n", "        \"\"\"Focal-style cross-entropy (modulated by gamma_neg only) 
for single-label classification.\"\"\"\n", "        # Per-example cross-entropy (class-weighted if weights were given)\n", "        if self.class_weights is not None:\n", "            ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')\n", "        else:\n", "            ce_loss = F.cross_entropy(logits, targets, reduction='none')\n", "\n", "        # Apply focal modulation\n", "        probs = F.softmax(logits, dim=-1)\n", "        # Probability assigned to the correct class\n", "        p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)\n", "\n", "        # Focal weight (1 - p_t)^gamma_neg: confident, correct predictions (p_t near 1)\n", "        # contribute little, so hard examples dominate. gamma_pos is unused in this branch.\n", "        focal_weight = (1 - p_t) ** self.gamma_neg\n", "\n", "        loss = focal_weight * ce_loss\n", "        return loss.mean()\n", "\n", "    def _multi_label_loss(self, logits, targets):\n", "        \"\"\"Full ASL for multi-label classification (targets are multi-hot floats).\"\"\"\n", "        p = torch.sigmoid(logits)\n", "\n", "        # Probability shifting: pm = max(p - clip, 0) fully discards very easy negatives\n", "        if self.clip is not None and self.clip > 0:\n", "            p_m = torch.clamp(p - self.clip, min=0)\n", "        else:\n", "            p_m = p\n", "\n", "        loss_pos = targets * (1 - p) ** self.gamma_pos * torch.log(p + self.eps)\n", "        loss_neg = (1 - targets) * p_m ** self.gamma_neg * torch.log(1 - p_m + self.eps)\n", "\n", "        loss = -(loss_pos + loss_neg)\n", "        return loss.mean()\n", "\n", "\n", "print(\"✅ AsymmetricLoss defined\")\n", "print(f\"  γ+ = {ASL_GAMMA_POS}, γ- = {ASL_GAMMA_NEG}, clip = {ASL_CLIP}\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 8: Custom Trainer with ASL" ], "metadata": {} }, { "cell_type": "code", "source": [ "from transformers import Trainer\n", "\n", "\n", "class ASLTrainer(Trainer):\n", "    \"\"\"Custom Trainer that uses Asymmetric Loss instead of standard CrossEntropy.\"\"\"\n", "\n", "    def __init__(self, *args, asl_loss_fn=None, **kwargs):\n", "        super().__init__(*args, **kwargs)\n", "        self.asl = asl_loss_fn\n", "\n", "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n", "        labels = 
inputs.pop(\"labels\")\n", "        outputs = model(**inputs)\n", "        logits = outputs.logits\n", "\n", "        if self.asl is not None:\n", "            loss = self.asl(logits, labels)\n", "        else:\n", "            # Fallback to standard cross-entropy\n", "            loss = F.cross_entropy(logits, labels)\n", "\n", "        return (loss, outputs) if return_outputs else loss\n", "\n", "\n", "print(\"✅ ASLTrainer defined\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 9: Metrics" ], "metadata": {} }, { "cell_type": "code", "source": [ "from sklearn.metrics import f1_score, precision_score, recall_score, classification_report\n", "\n", "\n", "def compute_metrics_single_label(eval_pred):\n", "    \"\"\"Metrics for single-label classification (LEDGAR & CUAD).\"\"\"\n", "    logits, labels = eval_pred.predictions, eval_pred.label_ids\n", "    preds = np.argmax(logits, axis=-1)\n", "\n", "    micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n", "    macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n", "    weighted_f1 = f1_score(labels, preds, average=\"weighted\", zero_division=0)\n", "    accuracy = (preds == labels).mean()\n", "\n", "    return {\n", "        \"accuracy\": accuracy,\n", "        \"micro_f1\": micro_f1,\n", "        \"macro_f1\": macro_f1,\n", "        \"weighted_f1\": weighted_f1,\n", "    }\n", "\n", "\n", "def compute_metrics_cuad_detailed(eval_pred):\n", "    \"\"\"Detailed metrics for CUAD, including per-class F1.\"\"\"\n", "    logits, labels = eval_pred.predictions, eval_pred.label_ids\n", "    preds = np.argmax(logits, axis=-1)\n", "\n", "    micro_f1 = f1_score(labels, preds, average=\"micro\", zero_division=0)\n", "    macro_f1 = f1_score(labels, preds, average=\"macro\", zero_division=0)\n", "    weighted_f1 = f1_score(labels, preds, average=\"weighted\", zero_division=0)\n", "    accuracy = (preds == labels).mean()\n", "\n", "    # Per-class F1 over all NUM_CUAD_LABELS ids, so index i always maps to CUAD_LABELS[i]\n", "    # (without labels=..., sklearn only returns the classes present in labels or preds).\n", "    per_class_f1 = f1_score(labels, preds, labels=list(range(NUM_CUAD_LABELS)), average=None, zero_division=0)\n", "    class_metrics = {}\n", "    for 
i, f1_val in enumerate(per_class_f1):\n", "        if i < len(CUAD_LABELS):\n", "            # Truncate label names for cleaner logging, and prefix the class id because\n", "            # truncation makes some names collide (e.g. both \"Affiliate License-...\" labels)\n", "            safe_name = CUAD_LABELS[i][:20].replace(\" \", \"_\").replace(\"/\", \"_\")\n", "            class_metrics[f\"f1_{i:02d}_{safe_name}\"] = float(f1_val)\n", "\n", "    return {\n", "        \"accuracy\": accuracy,\n", "        \"micro_f1\": micro_f1,\n", "        \"macro_f1\": macro_f1,\n", "        \"weighted_f1\": weighted_f1,\n", "        **class_metrics,\n", "    }\n", "\n", "\n", "print(\"✅ Metrics functions defined\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "---\n", "# 🏋️ STAGE 1: Pre-fine-tune on LEDGAR\n", "\n", "**Goal:** Teach DeBERTa-v3-large what types of contract clauses exist (100 classes, ~60K examples).\n", "\n", "This stage uses standard cross-entropy loss (the model's built-in loss for single-label classification). Imbalance handling is left to Stage 2.\n", "\n", "**Expected:** ~85-90% micro-F1 after 3-5 epochs (~3-5 hours on T4, ~1-2 hours on A100)" ], "metadata": {} }, { "cell_type": "code", "source": [ "from transformers import (\n", "    AutoConfig,\n", "    AutoModelForSequenceClassification,\n", "    TrainingArguments,\n", "    DataCollatorWithPadding,\n", "    EarlyStoppingCallback,\n", ")\n", "\n", "print(f\"🏋️ STAGE 1: Pre-fine-tune on LEDGAR ({num_ledgar_labels} classes)\")\n", "print(f\"  Loading {BASE_MODEL}...\")\n", "\n", "# Load model for Stage 1 (100 classes, single-label); the classification head is newly initialized\n", "stage1_model = AutoModelForSequenceClassification.from_pretrained(\n", "    BASE_MODEL,\n", "    num_labels=num_ledgar_labels,\n", "    problem_type=\"single_label_classification\",\n", "    ignore_mismatched_sizes=True,\n", ")\n", "\n", "total_params = sum(p.numel() for p in stage1_model.parameters())\n", "trainable_params = sum(p.numel() for p in stage1_model.parameters() if p.requires_grad)\n", "print(f\"  Total parameters: {total_params:,}\")\n", "print(f\"  Trainable parameters: {trainable_params:,}\")\n", "\n", "stage1_args = TrainingArguments(\n", "    output_dir=\"./stage1_ledgar\",\n", 
"    num_train_epochs=STAGE1_EPOCHS,\n", "    per_device_train_batch_size=STAGE1_BATCH,\n", "    per_device_eval_batch_size=4,\n", "    gradient_accumulation_steps=STAGE1_GRAD_ACCUM,\n", "    learning_rate=STAGE1_LR,\n", "    weight_decay=WEIGHT_DECAY,\n", "    warmup_ratio=WARMUP_RATIO,\n", "    lr_scheduler_type=\"cosine\",\n", "    eval_strategy=\"epoch\",\n", "    save_strategy=\"epoch\",\n", "    save_total_limit=2,\n", "    load_best_model_at_end=True,\n", "    metric_for_best_model=\"macro_f1\",\n", "    greater_is_better=True,\n", "    bf16=False,  # T4 (Turing) has no native bf16 support\n", "    fp16=False,  # DeBERTa-v3 breaks with the fp16 gradient scaler; fp32 is safest on T4\n", "    logging_strategy=\"steps\",\n", "    logging_steps=50,\n", "    logging_first_step=True,\n", "    disable_tqdm=False,\n", "    report_to=\"none\",\n", "    dataloader_num_workers=2,\n", "    seed=SEED,\n", "    gradient_checkpointing=True,  # trades compute for memory; critical on a 16 GB T4\n", ")\n", "\n", "stage1_trainer = Trainer(\n", "    model=stage1_model,\n", "    args=stage1_args,\n", "    train_dataset=ledgar_tokenized[\"train\"],\n", "    eval_dataset=ledgar_tokenized[\"validation\"],\n", "    processing_class=tokenizer,\n", "    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n", "    compute_metrics=compute_metrics_single_label,\n", "    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],\n", ")\n", "\n", "print(\"\\n🚀 Starting Stage 1 training...\")\n", "stage1_result = stage1_trainer.train()\n", "print(f\"\\n✅ Stage 1 complete! 
Loss: {stage1_result.training_loss:.4f}\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Evaluate Stage 1 on LEDGAR test set\n", "print(\"📊 Stage 1: LEDGAR Test Evaluation\")\n", "stage1_test = stage1_trainer.evaluate(ledgar_tokenized[\"test\"])\n", "print(f\"  Accuracy:    {stage1_test['eval_accuracy']:.4f}\")\n", "print(f\"  Micro-F1:    {stage1_test['eval_micro_f1']:.4f}\")\n", "print(f\"  Macro-F1:    {stage1_test['eval_macro_f1']:.4f}\")\n", "print(f\"  Weighted-F1: {stage1_test['eval_weighted_f1']:.4f}\")\n", "\n", "# Save Stage 1 checkpoint (the best model, since load_best_model_at_end=True)\n", "STAGE1_CHECKPOINT = \"./stage1_ledgar_best\"\n", "stage1_trainer.save_model(STAGE1_CHECKPOINT)\n", "tokenizer.save_pretrained(STAGE1_CHECKPOINT)\n", "print(f\"\\n💾 Stage 1 checkpoint saved to {STAGE1_CHECKPOINT}\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "---\n", "# 🏋️ STAGE 2: Fine-tune on the 41 CUAD Classes with Asymmetric Loss\n", "\n", "**Goal:** Learn the 41 CUAD contract clause types on top of the Stage 1 backbone.\n", "\n", "Key improvements over the current ClauseGuard model:\n", "- DeBERTa-v3-large backbone already fine-tuned on LEDGAR (Stage 1)\n", "- 512 tokens (vs 256), so more of each clause fits without truncation\n", "- Asymmetric Loss (focal-style, class-weighted) for class imbalance\n", "- Full fine-tuning (no LoRA bottleneck)\n", "\n", "**Expected:** 75-87% macro-F1 after 10-20 epochs (~5-8 hours on T4, ~2-4 hours on A100)" ], "metadata": {} }, { "cell_type": "code", "source": [ "# Free Stage 1 model memory before loading Stage 2\n", "import gc\n", "del stage1_model, stage1_trainer\n", "gc.collect()  # drop unreachable Python objects first...\n", "torch.cuda.empty_cache()  # ...then release the cached GPU memory they held\n", "\n", "print(f\"🏋️ STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL\")\n", "\n", "# Load Stage 1 checkpoint with a new head (100 → 41 classes)\n", "stage2_model = AutoModelForSequenceClassification.from_pretrained(\n", "    STAGE1_CHECKPOINT,\n", "    num_labels=NUM_CUAD_LABELS,\n", "    ignore_mismatched_sizes=True,  # 
classifier head: 100 → 41\n", "    problem_type=\"single_label_classification\",\n", ")\n", "\n", "print(f\"  Loaded Stage 1 backbone with new {NUM_CUAD_LABELS}-class head\")\n", "print(f\"  Parameters: {sum(p.numel() for p in stage2_model.parameters()):,}\")\n", "\n", "# Compute class weights from the training distribution\n", "train_class_counts = Counter(cuad_train_tok['labels'])\n", "total_samples = sum(train_class_counts.values())\n", "class_weights = []\n", "for i in range(NUM_CUAD_LABELS):\n", "    count = train_class_counts.get(i, 1)  # avoid division by zero for classes absent from train\n", "    # Inverse-frequency weight N / (K * n_c), capped at 10\n", "    weight = min(10.0, total_samples / (NUM_CUAD_LABELS * count))\n", "    class_weights.append(weight)\n", "\n", "print(f\"  Class weight range: [{min(class_weights):.2f}, {max(class_weights):.2f}]\")\n", "\n", "# Create ASL loss (single-label mode: gamma_pos and clip are ignored)\n", "asl_loss = AsymmetricLoss(\n", "    gamma_pos=ASL_GAMMA_POS,\n", "    gamma_neg=ASL_GAMMA_NEG,\n", "    clip=ASL_CLIP,\n", "    num_classes=NUM_CUAD_LABELS,\n", "    class_weights=class_weights,\n", "    mode=\"multi_class\",  # single-label per clause\n", ")\n", "# Move the class-weight buffer to the GPU\n", "if torch.cuda.is_available():\n", "    asl_loss = asl_loss.cuda()\n", "\n", "stage2_args = TrainingArguments(\n", "    output_dir=\"./stage2_cuad\",\n", "    num_train_epochs=STAGE2_EPOCHS,\n", "    per_device_train_batch_size=STAGE2_BATCH,\n", "    per_device_eval_batch_size=4,\n", "    gradient_accumulation_steps=STAGE2_GRAD_ACCUM,\n", "    learning_rate=STAGE2_LR,\n", "    weight_decay=WEIGHT_DECAY,\n", "    warmup_ratio=WARMUP_RATIO,\n", "    lr_scheduler_type=\"cosine\",\n", "    eval_strategy=\"epoch\",\n", "    save_strategy=\"epoch\",\n", "    save_total_limit=3,\n", "    load_best_model_at_end=True,\n", "    metric_for_best_model=\"macro_f1\",\n", "    greater_is_better=True,\n", "    bf16=False,  # T4 (Turing) has no native bf16 support\n", "    fp16=False,  # DeBERTa-v3 breaks with the fp16 gradient scaler; fp32 is safest on T4\n", "    logging_strategy=\"steps\",\n", "    logging_steps=25,\n", "    logging_first_step=True,\n", "    disable_tqdm=False,\n", "    report_to=\"none\",\n", "    
push_to_hub=True,  # also pushes each saved checkpoint during training\n", "    hub_model_id=HUB_MODEL_ID,\n", "    dataloader_num_workers=2,\n", "    seed=SEED,\n", "    gradient_checkpointing=True,\n", ")\n", "\n", "stage2_trainer = ASLTrainer(\n", "    model=stage2_model,\n", "    args=stage2_args,\n", "    asl_loss_fn=asl_loss,\n", "    train_dataset=cuad_train_tok,\n", "    eval_dataset=cuad_val_tok,\n", "    processing_class=tokenizer,\n", "    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n", "    compute_metrics=compute_metrics_cuad_detailed,\n", "    callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)],\n", ")\n", "\n", "print(\"\\n🚀 Starting Stage 2 training with Asymmetric Loss...\")\n", "stage2_result = stage2_trainer.train()\n", "print(f\"\\n✅ Stage 2 complete! Loss: {stage2_result.training_loss:.4f}\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 10: Evaluate Stage 2 on CUAD Test Set" ], "metadata": {} }, { "cell_type": "code", "source": [ "print(\"📊 Stage 2: CUAD Test Evaluation\")\n", "test_results = stage2_trainer.evaluate(cuad_test_tok)\n", "\n", "print(f\"\\n{'='*60}\")\n", "print(\"  CUAD TEST RESULTS (DeBERTa-v3-large + LEDGAR + ASL)\")\n", "print(f\"{'='*60}\")\n", "print(f\"  Accuracy:    {test_results['eval_accuracy']:.4f}\")\n", "print(f\"  Micro-F1:    {test_results['eval_micro_f1']:.4f}\")\n", "print(f\"  Macro-F1:    {test_results['eval_macro_f1']:.4f}\")\n", "print(f\"  Weighted-F1: {test_results['eval_weighted_f1']:.4f}\")\n", "print(f\"{'='*60}\")\n", "\n", "# Per-class F1 report (classes absent from the test split also show 0.0)\n", "print(\"\\n  Per-class F1 scores:\")\n", "print(f\"  {'Class':<42s} {'F1':>6s}\")\n", "print(f\"  {'-'*48}\")\n", "\n", "zero_f1_classes = []\n", "for i, label_name in enumerate(CUAD_LABELS):\n", "    safe_name = label_name[:20].replace(\" \", \"_\").replace(\"/\", \"_\")\n", "    key = f\"eval_f1_{i:02d}_{safe_name}\"  # same key format as compute_metrics_cuad_detailed\n", "    f1_val = test_results.get(key, 0.0)\n", "    bar = '█' * int(f1_val * 30)\n", "    status = \"\" if f1_val > 0 else \"  ← ZERO\"\n", "    
print(f\"  {i:2d} {label_name:<40s} {f1_val:.4f} {bar}{status}\")\n", "    if f1_val == 0:\n", "        zero_f1_classes.append(label_name)\n", "\n", "print(f\"\\n  Classes with zero F1: {len(zero_f1_classes)}\")\n", "if zero_f1_classes:\n", "    for c in zero_f1_classes:\n", "        print(f\"    ⚠️ {c}\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 11: Full Classification Report" ], "metadata": {} }, { "cell_type": "code", "source": [ "# Generate full sklearn classification report\n", "from sklearn.metrics import classification_report\n", "\n", "# Get predictions on test set\n", "preds_output = stage2_trainer.predict(cuad_test_tok)\n", "preds = np.argmax(preds_output.predictions, axis=-1)\n", "labels = preds_output.label_ids\n", "\n", "# Only include classes that appear in the test labels or in the predictions\n", "present_labels = sorted(set(labels) | set(preds))\n", "target_names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f\"Class-{i}\" for i in present_labels]\n", "\n", "report = classification_report(\n", "    labels, preds,\n", "    labels=present_labels,\n", "    target_names=target_names,\n", "    zero_division=0,\n", "    digits=4,\n", ")\n", "print(\"\\n📊 Full Classification Report:\")\n", "print(report)" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 12: Push Final Model to Hub" ], "metadata": {} }, { "cell_type": "code", "source": [ "# Save model with proper label mapping.\n", "# id2label keys must be ints: the pipeline in Step 13 looks labels up by integer class id.\n", "stage2_model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}\n", "stage2_model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}\n", "\n", "# Save locally\n", "FINAL_DIR = \"./clauseguard-deberta-final\"\n", "stage2_trainer.save_model(FINAL_DIR)\n", "tokenizer.save_pretrained(FINAL_DIR)\n", "\n", "# Push to Hub\n", "print(f\"\\n☁️ Pushing model to Hub: {HUB_MODEL_ID}\")\n", "stage2_trainer.push_to_hub(\n", "    commit_message=(\n", "        f\"ClauseGuard v4: DeBERTa-v3-large 2-stage 
(LEDGAR→CUAD) with ASL\\n\"\n", "        f\"CUAD Test: micro-F1={test_results['eval_micro_f1']:.4f}, \"\n", "        f\"macro-F1={test_results['eval_macro_f1']:.4f}\"\n", "    )\n", ")\n", "\n", "print(f\"\\n✅ Model pushed to: https://huggingface.co/{HUB_MODEL_ID}\")" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 13: Test the Model on Sample Clauses" ], "metadata": {} }, { "cell_type": "code", "source": [ "from transformers import pipeline as hf_pipeline\n", "\n", "# Wrap the trained model in a text-classification pipeline for inference\n", "classifier = hf_pipeline(\n", "    \"text-classification\",\n", "    model=stage2_model,\n", "    tokenizer=tokenizer,\n", "    top_k=5,  # return top 5 predictions\n", "    device=0 if torch.cuda.is_available() else -1,\n", ")\n", "\n", "test_clauses = [\n", "    # High-risk clauses\n", "    \"The Company may terminate this Agreement at any time, with or without cause, upon written notice to the other party.\",\n", "    \"In no event shall the Company be liable for any indirect, incidental, special, or consequential damages arising out of this Agreement.\",\n", "    \"All intellectual property developed during the term of this Agreement shall be owned exclusively by the Company.\",\n", "    \"This Agreement shall be governed by and construed in accordance with the laws of the State of Delaware.\",\n", "    \"Any disputes arising out of this Agreement shall be resolved through binding arbitration in New York.\",\n", "    \"The Employee agrees not to compete with the Company for a period of two (2) years following termination.\",\n", "    # Neutral clauses\n", "    \"This Agreement shall be effective as of January 1, 2024.\",\n", "    \"The initial term of this Agreement shall be three (3) years.\",\n", "    \"Either party may assign this Agreement with the prior written consent of the other party.\",\n", "]\n", "\n", "print(\"🧪 Testing model on sample clauses:\\n\")\n", "for clause in test_clauses:\n", "    results = classifier(clause, truncation=True, 
max_length=MAX_LENGTH)\n", "    # Depending on the transformers version, results may be a flat list of dicts or nested one level deeper\n", "    top3 = results[:3] if isinstance(results[0], dict) else results[0][:3]\n", "\n", "    print(f\"📄 \\\"{clause[:90]}{'...' if len(clause) > 90 else ''}\\\"\")\n", "    for r in top3:\n", "        score = r['score']\n", "        bar = '█' * int(score * 20)\n", "        print(f\"   → {r['label']:40s} {score:.4f} {bar}\")\n", "    print()" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 14: Generate Updated app.py Integration Code\n", "\n", "Paste the printed code into your ClauseGuard Space's `app.py` to switch to the new model." ], "metadata": {} }, { "cell_type": "code", "source": [ "integration_code = f'''\n", "# ═══════════════════════════════════════════════════════════════\n", "# ClauseGuard v4: Integration Code\n", "# Replace the model loading section in app.py with this:\n", "# ═══════════════════════════════════════════════════════════════\n", "\n", "# OLD (remove these):\n", "# base = \"nlpaueb/legal-bert-base-uncased\"\n", "# adapter = \"Mokshith31/legalbert-contract-clause-classification\"\n", "# from peft import PeftModel\n", "\n", "# NEW:\n", "CLAUSEGUARD_MODEL = \"{HUB_MODEL_ID}\"\n", "\n", "def _load_cuad_model():\n", "    global cuad_tokenizer, cuad_model, _model_status\n", "    if not _HAS_TORCH:\n", "        _model_status[\"cuad\"] = \"unavailable\"\n", "        return\n", "    try:\n", "        print(f\"[ClauseGuard] Loading classifier: {{CLAUSEGUARD_MODEL}}\")\n", "        cuad_tokenizer = AutoTokenizer.from_pretrained(CLAUSEGUARD_MODEL)\n", "        cuad_model = AutoModelForSequenceClassification.from_pretrained(CLAUSEGUARD_MODEL)\n", "        cuad_model.eval()\n", "        _model_status[\"cuad\"] = \"loaded\"\n", "        print(f\"[ClauseGuard] Model loaded: {{sum(p.numel() for p in cuad_model.parameters()):,}} params\")\n", "    except Exception as e:\n", "        print(f\"[ClauseGuard] Model load failed: {{e}}\")\n", "        _model_status[\"cuad\"] = f\"failed: {{e}}\"\n", "\n", "# In 
classify_cuad(), change max_length:\n", "#   max_length=256 → max_length=512\n", "#\n", "# Also: since the new model is single-label (softmax),\n", "# change the prediction logic from sigmoid to:\n", "#\n", "#   probs = torch.softmax(logits, dim=-1)[0]  # instead of sigmoid\n", "#   top_indices = torch.argsort(probs, descending=True)[:5]\n", "#   for i in top_indices:\n", "#       if float(probs[i]) > 0.10:  # confidence threshold\n", "#           label = cuad_model.config.id2label[int(i)]  # label names are stored in the model config\n", "#           ...\n", "\n", "# No PEFT dependency needed anymore.\n", "# No ignore_mismatched_sizes needed either: the saved model already has the correct\n", "# 41-class head, so it loads directly.\n", "'''\n", "\n", "print(integration_code)" ], "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Step 15: Comparison with Current Model\n", "\n", "| Metric | Current (Legal-BERT + LoRA) | New (DeBERTa-v3-large + ASL) |\n", "|--------|-----------------------------|------------------------------|\n", "| Base model | 110M params | 435M params |\n", "| Training | LoRA (frozen backbone) | Full fine-tune |\n", "| Intermediate fine-tuning | None | LEDGAR (60K, 100 classes) |\n", "| Max tokens | 256 | 512 |\n", "| Loss function | Cross-entropy | Asymmetric Loss |\n", "| Zero-F1 classes | 10 of 41 | TBD (expected to be fewer) |\n", "| Macro-F1 | ~50% | Target: 78-87% |\n", "\n", "---\n", "\n", "## ✅ Done!\n", "\n", "Your trained model is at: **https://huggingface.co/gaurv007/clauseguard-deberta-v3-large**\n", "\n", "### Next Steps:\n", "1. Update the ClauseGuard Space to use this model (see the integration code in Step 14)\n", "2. Remove the PEFT dependency from requirements.txt\n", "3. Consider training SetFit classifiers for any remaining zero-F1 classes\n", "4. Add OCR support (Feature #2)\n", "5. Add RAG chatbot (Feature #3)" ], "metadata": {} } ] }