"""
ClauseGuard v4 — 2-Stage DeBERTa-v3-large Training Script
═══════════════════════════════════════════════════════════
Stage 1: Pre-fine-tune on LEDGAR (60K legal provisions, 100 classes)
Stage 2: Fine-tune on CUAD (41 classes) with Asymmetric Loss

Usage:
    python train_classifier_v4.py                  # Full 2-stage pipeline
    python train_classifier_v4.py --stage 1        # Stage 1 only
    python train_classifier_v4.py --stage 2 --checkpoint ./stage1_ledgar_best  # Stage 2 only

Requirements:
    pip install transformers datasets scikit-learn accelerate torch

Hardware: A100 80GB recommended (~4-6 hours total)
"""

import os
import gc
import argparse
import json
from collections import Counter
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from sklearn.model_selection import train_test_split
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)

# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════

# All of these can be overridden through environment variables of the same name.
BASE_MODEL = os.environ.get("BASE_MODEL", "microsoft/deberta-v3-large")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-deberta-v3-large")
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
SEED = 42

# NOTE(review): the position of each name is used as the class id in Stage 2
# (see ``run_stage2``); confirm this order matches the dataset's ``class_id``.
CUAD_LABELS = [
    "Document Name",
    "Parties",
    "Agreement Date",
    "Effective Date",
    "Expiration Date",
    "Renewal Term",
    "Notice Period to Terminate Renewal",
    "Governing Law",
    "Most Favored Nation",
    "Non-Compete",
    "Exclusivity",
    "No-Solicit of Customers",
    "No-Solicit of Employees",
    "Non-Disparagement",
    "Termination for Convenience",
    "ROFR/ROFO/ROFN",
    "Change of Control",
    "Anti-Assignment",
    "Revenue/Profit Sharing",
    "Price Restriction",
    "Minimum Commitment",
    "Volume Restriction",
    "IP Ownership Assignment",
    "Joint IP Ownership",
    "License Grant",
    "Non-Transferable License",
    "Affiliate License-Licensor",
    "Affiliate License-Licensee",
    "Unlimited/All-You-Can-Eat License",
    "Irrevocable or Perpetual License",
    "Source Code Escrow",
    "Post-Termination Services",
    "Audit Rights",
    "Uncapped Liability",
    "Cap on Liability",
    "Liquidated Damages",
    "Warranty Duration",
    "Insurance",
    "Covenant Not to Sue",
    "Third Party Beneficiary",
    "Other",
]
NUM_CUAD_LABELS = len(CUAD_LABELS)


def _mixed_precision_flags():
    """Return the ``bf16``/``fp16`` TrainingArguments flags for this machine.

    bf16 is selected when CUDA is available and the device's major compute
    capability is >= 8; fp16 is selected on other CUDA devices; both are
    False when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"bf16": False, "fp16": False}
    supports_bf16 = torch.cuda.get_device_capability()[0] >= 8
    return {"bf16": supports_bf16, "fp16": not supports_bf16}


# ═══════════════════════════════════════════════════════════════
# ASYMMETRIC LOSS (arxiv:2009.14119)
# ═══════════════════════════════════════════════════════════════

class AsymmetricLoss(nn.Module):
    """Focal-style, optionally class-weighted cross-entropy for imbalanced data.

    ``forward`` computes a single-label focal cross-entropy whose focusing
    exponent is ``gamma_neg``. ``gamma_pos``, ``clip`` and ``eps`` are stored
    on the module but are not used by ``forward``.

    Args:
        gamma_pos: Stored only (unused by ``forward``).
        gamma_neg: Exponent applied to ``(1 - p_t)`` in the focal weight.
        clip: Stored only (unused by ``forward``).
        eps: Stored only (unused by ``forward``).
        class_weights: Optional per-class weights (one float per class)
            passed to the cross-entropy term.
    """

    def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8, class_weights=None):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.eps = eps
        if class_weights is not None:
            # A buffer follows the module through .to()/.cuda() and is not trained.
            self.register_buffer(
                "class_weights", torch.tensor(class_weights, dtype=torch.float32)
            )
        else:
            self.class_weights = None

    def forward(self, logits, targets):
        """Multi-class focal cross-entropy with class weights.

        Args:
            logits: Unnormalised class scores; indexed as (batch, classes)
                by the ``gather`` call below.
            targets: Integer class index per example.

        Returns:
            Scalar tensor: mean over the batch of ``(1 - p_t) ** gamma_neg``
            times the (optionally weighted) cross-entropy.
        """
        weight = self.class_weights
        if weight is not None:
            # Keep the weights on the logits' device so the module works even
            # if it was never moved explicitly (no-op when already aligned).
            weight = weight.to(logits.device)

        ce_loss = F.cross_entropy(logits, targets, weight=weight, reduction="none")

        # p_t is the predicted probability of the true class.
        probs = F.softmax(logits, dim=-1)
        p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        focal_weight = (1 - p_t) ** self.gamma_neg
        return (focal_weight * ce_loss).mean()


# ═══════════════════════════════════════════════════════════════
# CUSTOM TRAINER
# ═══════════════════════════════════════════════════════════════

class ASLTrainer(Trainer):
    """``Trainer`` that computes its loss with a caller-supplied loss module.

    Args:
        asl_loss_fn: Callable ``(logits, labels) -> loss``. When ``None``,
            plain cross-entropy is used instead.
    """

    def __init__(self, *args, asl_loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.asl = asl_loss_fn

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model without labels and compute the loss on its logits.

        ``labels`` is popped from ``inputs`` so the model does not compute its
        own loss. ``num_items_in_batch`` is accepted for signature
        compatibility and is not used.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        if self.asl is not None:
            loss = self.asl(logits, labels)
        else:
            loss = F.cross_entropy(logits, labels)
        return (loss, outputs) if return_outputs else loss


# ═══════════════════════════════════════════════════════════════
# METRICS
# ═══════════════════════════════════════════════════════════════

def compute_metrics(eval_pred):
    """Return accuracy and micro/macro/weighted F1 from arg-max predictions."""
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": (preds == labels).mean(),
        "micro_f1": f1_score(labels, preds, average="micro", zero_division=0),
        "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
        "weighted_f1": f1_score(labels, preds, average="weighted", zero_division=0),
    }


# ═══════════════════════════════════════════════════════════════
# STAGE 1: LEDGAR
# ═══════════════════════════════════════════════════════════════

def run_stage1(tokenizer, output_dir="./stage1_ledgar_best"):
    """Fine-tune ``BASE_MODEL`` on LEDGAR and save the best checkpoint.

    Args:
        tokenizer: Tokenizer used to encode the provisions; also saved next
            to the model.
        output_dir: Directory that receives the best model and the tokenizer.

    Returns:
        ``output_dir``, for use as the Stage 2 starting checkpoint.
    """
    print("\n" + "=" * 60)
    print(" STAGE 1: Pre-fine-tune on LEDGAR (100 classes)")
    print("=" * 60)

    ledgar = load_dataset("coastalcph/lex_glue", "ledgar")
    num_labels = ledgar["train"].features["label"].num_classes
    print(f" Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,}")
    print(f" Classes: {num_labels}")

    def preprocess(examples):
        # Padding is deferred to the data collator (dynamic per-batch padding).
        tok = tokenizer(
            examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False
        )
        tok["labels"] = examples["label"]
        return tok

    tokenized = ledgar.map(
        preprocess, batched=True, remove_columns=ledgar["train"].column_names
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        num_labels=num_labels,
        problem_type="single_label_classification",
        ignore_mismatched_sizes=True,
    )
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    args = TrainingArguments(
        output_dir="./stage1_ledgar",
        num_train_epochs=5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        weight_decay=0.06,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=50,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        dataloader_num_workers=2,
        seed=SEED,
        gradient_checkpointing=True,
        **_mixed_precision_flags(),
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        processing_class=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    result = trainer.train()
    print(f"\n Stage 1 training loss: {result.training_loss:.4f}")

    test_metrics = trainer.evaluate(tokenized["test"])
    print(f" Stage 1 test micro-F1: {test_metrics['eval_micro_f1']:.4f}")
    print(f" Stage 1 test macro-F1: {test_metrics['eval_macro_f1']:.4f}")

    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f" Saved to {output_dir}")

    # Drop the references and collect garbage first, so the cached CUDA
    # blocks are actually unused by the time the cache is emptied.
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()

    return output_dir


# ═══════════════════════════════════════════════════════════════
# STAGE 2: CUAD
# ═══════════════════════════════════════════════════════════════

def run_stage2(tokenizer, checkpoint_path, output_dir="./clauseguard-deberta-final"):
    """Fine-tune a Stage 1 checkpoint on CUAD clause classification.

    Splits CUAD by contract file (80/10/10), trains with ``AsymmetricLoss``
    and capped inverse-frequency class weights, reports test metrics, saves
    the model (and pushes it to the Hub when ``PUSH_TO_HUB`` is set), and
    writes ``test_results.json`` into ``output_dir``.

    Args:
        tokenizer: Tokenizer used to encode the clauses; also saved next to
            the model.
        checkpoint_path: Model path or id to start from (the Stage 1 output).
        output_dir: Directory that receives the final model and results.

    Returns:
        ``output_dir``.
    """
    print("\n" + "=" * 60)
    print(f" STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL")
    print("=" * 60)

    # Load and split CUAD. Splitting on file names keeps every clause of a
    # contract inside a single split.
    cuad_raw = load_dataset(
        "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification", split="train"
    )
    cuad_df = cuad_raw.to_pandas()
    unique_files = cuad_df["file_name"].unique()
    train_files, test_files = train_test_split(
        unique_files, test_size=0.2, random_state=SEED
    )
    val_files, test_files = train_test_split(
        test_files, test_size=0.5, random_state=SEED
    )

    split_files = {"train": train_files, "val": val_files, "test": test_files}
    splits = {
        name: Dataset.from_pandas(
            cuad_df[cuad_df["file_name"].isin(files)].reset_index(drop=True)
        )
        for name, files in split_files.items()
    }
    for name, ds in splits.items():
        print(f" {name}: {len(ds)} rows")

    def preprocess_cuad(examples):
        # Padding is deferred to the data collator (dynamic per-batch padding).
        tok = tokenizer(
            examples["clause"], truncation=True, max_length=MAX_LENGTH, padding=False
        )
        tok["labels"] = examples["class_id"]
        return tok

    tok_splits = {
        name: ds.map(preprocess_cuad, batched=True, remove_columns=ds.column_names)
        for name, ds in splits.items()
    }

    # Load model from Stage 1 checkpoint; the classifier head is re-sized
    # to the CUAD label count, hence ignore_mismatched_sizes.
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint_path,
        num_labels=NUM_CUAD_LABELS,
        ignore_mismatched_sizes=True,
        problem_type="single_label_classification",
    )

    # Update label mapping. id2label is keyed by int so that in-process
    # lookups by predicted class id work.
    model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}
    model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Compute class weights: inverse frequency, capped at 10. Classes absent
    # from the training split are treated as having one example.
    train_counts = Counter(tok_splits["train"]["labels"])
    total = sum(train_counts.values())
    class_weights = [
        min(10.0, total / (NUM_CUAD_LABELS * train_counts.get(i, 1)))
        for i in range(NUM_CUAD_LABELS)
    ]

    # The loss moves its class weights to the logits' device on each call,
    # so no explicit device placement is needed here.
    asl = AsymmetricLoss(gamma_pos=0, gamma_neg=4, clip=0.05, class_weights=class_weights)

    args = TrainingArguments(
        output_dir="./stage2_cuad",
        num_train_epochs=20,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        weight_decay=0.06,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        push_to_hub=PUSH_TO_HUB,
        hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
        dataloader_num_workers=2,
        seed=SEED,
        gradient_checkpointing=True,
        **_mixed_precision_flags(),
    )

    trainer = ASLTrainer(
        model=model,
        args=args,
        asl_loss_fn=asl,
        train_dataset=tok_splits["train"],
        eval_dataset=tok_splits["val"],
        processing_class=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    result = trainer.train()
    print(f"\n Stage 2 training loss: {result.training_loss:.4f}")

    # Evaluate: a single pass over the test set yields both the metrics
    # (under the "eval_" prefix) and the raw predictions for the report.
    preds_out = trainer.predict(tok_splits["test"], metric_key_prefix="eval")
    test_metrics = preds_out.metrics
    print(f"\n{'='*60}")
    print(" CUAD TEST RESULTS")
    print(f"{'='*60}")
    print(f" Accuracy: {test_metrics['eval_accuracy']:.4f}")
    print(f" Micro-F1: {test_metrics['eval_micro_f1']:.4f}")
    print(f" Macro-F1: {test_metrics['eval_macro_f1']:.4f}")
    print(f" Weighted-F1: {test_metrics['eval_weighted_f1']:.4f}")

    # Full report, restricted to classes seen in the labels or predictions.
    preds = np.argmax(preds_out.predictions, axis=-1)
    labels = preds_out.label_ids
    present = sorted(set(labels) | set(preds))
    names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f"Class-{i}" for i in present]
    print("\n" + classification_report(labels, preds, labels=present,
                                       target_names=names, zero_division=0, digits=4))

    # Save
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    if PUSH_TO_HUB:
        trainer.push_to_hub(
            commit_message=(
                f"ClauseGuard v4: DeBERTa-v3-large LEDGAR→CUAD + ASL | "
                f"micro-F1={test_metrics['eval_micro_f1']:.4f} "
                f"macro-F1={test_metrics['eval_macro_f1']:.4f}"
            )
        )
        print(f"\n Pushed to https://huggingface.co/{HUB_MODEL_ID}")

    # Save test results
    results_path = os.path.join(output_dir, "test_results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump({
            "model": HUB_MODEL_ID,
            "base_model": BASE_MODEL,
            "max_length": MAX_LENGTH,
            "stage1_dataset": "coastalcph/lex_glue (ledgar)",
            "stage2_dataset": "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
            "test_results": {
                k: float(v)
                for k, v in test_metrics.items()
                if isinstance(v, (int, float))
            },
            "timestamp": datetime.now().isoformat(),
        }, f, indent=2)

    return output_dir


# ═══════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════

def main():
    """Parse CLI arguments and run Stage 1, Stage 2, or both (default)."""
    parser = argparse.ArgumentParser(description="ClauseGuard v4 Training")
    # choices rejects values such as ``--stage 3`` that would otherwise run
    # nothing and still report success.
    parser.add_argument("--stage", type=int, default=0, choices=(0, 1, 2),
                        help="Run specific stage (1 or 2). Default: both")
    parser.add_argument("--checkpoint", type=str, default="./stage1_ledgar_best",
                        help="Stage 1 checkpoint path for Stage 2")
    args = parser.parse_args()

    print("🛡️ ClauseGuard v4 Training")
    print(f" Model: {BASE_MODEL}")
    print(f" Max length: {MAX_LENGTH}")
    print(f" Hub: {HUB_MODEL_ID}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        # The device-properties attribute is ``total_memory`` (in bytes).
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    if args.stage in (0, 1):
        checkpoint = run_stage1(tokenizer)
    else:
        checkpoint = args.checkpoint

    if args.stage in (0, 2):
        run_stage2(tokenizer, checkpoint)

    print("\n✅ Training complete!")


if __name__ == "__main__":
    main()