| """ | |
| ClauseGuard v4 β 2-Stage DeBERTa-v3-large Training Script | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| Stage 1: Pre-fine-tune on LEDGAR (60K legal provisions, 100 classes) | |
| Stage 2: Fine-tune on CUAD (41 classes) with Asymmetric Loss | |
| Usage: | |
| python train_classifier_v4.py # Full 2-stage pipeline | |
| python train_classifier_v4.py --stage 1 # Stage 1 only | |
| python train_classifier_v4.py --stage 2 --checkpoint ./stage1_ledgar_best # Stage 2 only | |
| Requirements: | |
| pip install transformers datasets scikit-learn accelerate torch | |
| Hardware: A100 80GB recommended (~4-6 hours total) | |
| """ | |
import os
import gc
import argparse
import json
from collections import Counter
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from sklearn.model_selection import train_test_split
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)
# ───────────────────────────────────────────────────────────────
# CONFIGURATION
# ───────────────────────────────────────────────────────────────
BASE_MODEL = os.environ.get("BASE_MODEL", "microsoft/deberta-v3-large")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-deberta-v3-large")
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
SEED = 42
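# All of the above can be overridden via environment variables. The values in
# this example invocation are illustrative only, not recommended settings:
#   BASE_MODEL=microsoft/deberta-v3-base MAX_LENGTH=256 PUSH_TO_HUB=false \
#       python train_classifier_v4.py --stage 1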
CUAD_LABELS = [
    "Document Name", "Parties", "Agreement Date", "Effective Date",
    "Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
    "Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
    "No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
    "Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
    "Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
    "Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
    "Joint IP Ownership", "License Grant", "Non-Transferable License",
    "Affiliate License-Licensor", "Affiliate License-Licensee",
    "Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
    "Source Code Escrow", "Post-Termination Services", "Audit Rights",
    "Uncapped Liability", "Cap on Liability", "Liquidated Damages",
    "Warranty Duration", "Insurance", "Covenant Not to Sue",
    "Third Party Beneficiary", "Other",
]
NUM_CUAD_LABELS = len(CUAD_LABELS)
# ───────────────────────────────────────────────────────────────
# ASYMMETRIC LOSS (arXiv:2009.14119)
# ───────────────────────────────────────────────────────────────
class AsymmetricLoss(nn.Module):
    """Focal-style loss with asymmetric gamma for class imbalance."""

    def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,
                 class_weights=None):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.eps = eps
        if class_weights is not None:
            self.register_buffer('class_weights',
                                 torch.tensor(class_weights, dtype=torch.float32))
        else:
            self.class_weights = None

    def forward(self, logits, targets):
        """Multi-class focal cross-entropy with class weights."""
        if self.class_weights is not None:
            ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights,
                                      reduction='none')
        else:
            ce_loss = F.cross_entropy(logits, targets, reduction='none')
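        # Focal modulation: p_t is the predicted probability of the true class;
        # scaling CE by (1 - p_t) ** gamma_neg down-weights easy, high-confidence
        # examples so gradients focus on hard / minority-class clauses. This
        # multi-class adaptation applies only gamma_neg; gamma_pos and clip from
        # the original ASL formulation are kept for API compatibility but unused.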
        probs = F.softmax(logits, dim=-1)
        p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        focal_weight = (1 - p_t) ** self.gamma_neg
        loss = focal_weight * ce_loss
        return loss.mean()
# ───────────────────────────────────────────────────────────────
# CUSTOM TRAINER
# ───────────────────────────────────────────────────────────────
class ASLTrainer(Trainer):
    def __init__(self, *args, asl_loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.asl = asl_loss_fn

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch=None):
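        # Pop labels before the forward pass so the model does not compute its
        # built-in cross-entropy; the loss is applied to the raw logits below.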
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        if self.asl is not None:
            loss = self.asl(logits, labels)
        else:
            loss = F.cross_entropy(logits, labels)
        return (loss, outputs) if return_outputs else loss
# ───────────────────────────────────────────────────────────────
# METRICS
# ───────────────────────────────────────────────────────────────
def compute_metrics(eval_pred):
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": (preds == labels).mean(),
        "micro_f1": f1_score(labels, preds, average="micro", zero_division=0),
        "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
        "weighted_f1": f1_score(labels, preds, average="weighted", zero_division=0),
    }
# ───────────────────────────────────────────────────────────────
# STAGE 1: LEDGAR
# ───────────────────────────────────────────────────────────────
def run_stage1(tokenizer, output_dir="./stage1_ledgar_best"):
    print("\n" + "=" * 60)
    print(" STAGE 1: Pre-fine-tune on LEDGAR (100 classes)")
    print("=" * 60)
    ledgar = load_dataset("coastalcph/lex_glue", "ledgar")
    num_labels = ledgar['train'].features['label'].num_classes
    print(f" Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,}")
    print(f" Classes: {num_labels}")

    def preprocess(examples):
        tok = tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH,
                        padding=False)
        tok["labels"] = examples["label"]
        return tok

    tokenized = ledgar.map(preprocess, batched=True,
                           remove_columns=ledgar["train"].column_names)
    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=num_labels,
        problem_type="single_label_classification",
        ignore_mismatched_sizes=True,
    )
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
    args = TrainingArguments(
        output_dir="./stage1_ledgar",
        num_train_epochs=5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
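        # Effective train batch size: 8 per device x 4 accumulation steps = 32.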
        learning_rate=2e-5,
        weight_decay=0.06,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
        fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
        logging_strategy="steps",
        logging_steps=50,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        dataloader_num_workers=2,
        seed=SEED,
        gradient_checkpointing=True,
    )
    trainer = Trainer(
        model=model, args=args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        processing_class=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
    result = trainer.train()
    print(f"\n Stage 1 training loss: {result.training_loss:.4f}")
    test_metrics = trainer.evaluate(tokenized["test"])
    print(f" Stage 1 test micro-F1: {test_metrics['eval_micro_f1']:.4f}")
    print(f" Stage 1 test macro-F1: {test_metrics['eval_macro_f1']:.4f}")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f" Saved to {output_dir}")
    del model, trainer
    torch.cuda.empty_cache()
    gc.collect()
    return output_dir
# ───────────────────────────────────────────────────────────────
# STAGE 2: CUAD
# ───────────────────────────────────────────────────────────────
def run_stage2(tokenizer, checkpoint_path, output_dir="./clauseguard-deberta-final"):
    print("\n" + "=" * 60)
    print(f" STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL")
    print("=" * 60)
    # Load and split CUAD
    cuad_raw = load_dataset(
        "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
        split="train"
    )
    cuad_df = cuad_raw.to_pandas()
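    # Split at the contract (file) level so clauses from the same document never
    # leak across splits: 80% of files for training, with the remaining 20%
    # halved into validation and test.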
    unique_files = cuad_df['file_name'].unique()
    train_files, test_files = train_test_split(unique_files, test_size=0.2,
                                               random_state=SEED)
    val_files, test_files = train_test_split(test_files, test_size=0.5,
                                             random_state=SEED)
    splits = {
        "train": Dataset.from_pandas(
            cuad_df[cuad_df['file_name'].isin(train_files)].reset_index(drop=True)
        ),
        "val": Dataset.from_pandas(
            cuad_df[cuad_df['file_name'].isin(val_files)].reset_index(drop=True)
        ),
        "test": Dataset.from_pandas(
            cuad_df[cuad_df['file_name'].isin(test_files)].reset_index(drop=True)
        ),
    }
    for name, ds in splits.items():
        print(f" {name}: {len(ds)} rows")

    def preprocess_cuad(examples):
        tok = tokenizer(examples["clause"], truncation=True, max_length=MAX_LENGTH,
                        padding=False)
        tok["labels"] = examples["class_id"]
        return tok

    tok_splits = {}
    for name, ds in splits.items():
        tok_splits[name] = ds.map(preprocess_cuad, batched=True,
                                  remove_columns=ds.column_names)
    # Load model from Stage 1 checkpoint
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint_path,
        num_labels=NUM_CUAD_LABELS,
        ignore_mismatched_sizes=True,
        problem_type="single_label_classification",
    )
    # Update label mapping
    model.config.id2label = {str(i): name for i, name in enumerate(CUAD_LABELS)}
    model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
    # Compute class weights
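    # Inverse-frequency weighting: w_i = total / (num_classes * count_i), capped
    # at 10.0 so rare clause types are up-weighted without letting a handful of
    # examples dominate the loss.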
    train_counts = Counter(tok_splits["train"]["labels"])
    total = sum(train_counts.values())
    class_weights = []
    for i in range(NUM_CUAD_LABELS):
        count = train_counts.get(i, 1)
        weight = min(10.0, total / (NUM_CUAD_LABELS * count))
        class_weights.append(weight)
    asl = AsymmetricLoss(gamma_pos=0, gamma_neg=4, clip=0.05,
                         class_weights=class_weights)
    if torch.cuda.is_available():
        asl = asl.cuda()
    args = TrainingArguments(
        output_dir="./stage2_cuad",
        num_train_epochs=20,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        weight_decay=0.06,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
        fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        push_to_hub=PUSH_TO_HUB,
        hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
        dataloader_num_workers=2,
        seed=SEED,
        gradient_checkpointing=True,
    )
    trainer = ASLTrainer(
        model=model, args=args,
        asl_loss_fn=asl,
        train_dataset=tok_splits["train"],
        eval_dataset=tok_splits["val"],
        processing_class=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    result = trainer.train()
    print(f"\n Stage 2 training loss: {result.training_loss:.4f}")
    # Evaluate
    test_metrics = trainer.evaluate(tok_splits["test"])
    print(f"\n{'='*60}")
    print(" CUAD TEST RESULTS")
    print(f"{'='*60}")
    print(f" Accuracy: {test_metrics['eval_accuracy']:.4f}")
    print(f" Micro-F1: {test_metrics['eval_micro_f1']:.4f}")
    print(f" Macro-F1: {test_metrics['eval_macro_f1']:.4f}")
    print(f" Weighted-F1: {test_metrics['eval_weighted_f1']:.4f}")
    # Full report
    preds_out = trainer.predict(tok_splits["test"])
    preds = np.argmax(preds_out.predictions, axis=-1)
    labels = preds_out.label_ids
    present = sorted(set(labels) | set(preds))
    names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f"Class-{i}" for i in present]
    print("\n" + classification_report(labels, preds, labels=present,
                                       target_names=names, zero_division=0, digits=4))
    # Save
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    if PUSH_TO_HUB:
        trainer.push_to_hub(
            commit_message=(
                f"ClauseGuard v4: DeBERTa-v3-large LEDGAR→CUAD + ASL | "
                f"micro-F1={test_metrics['eval_micro_f1']:.4f} "
                f"macro-F1={test_metrics['eval_macro_f1']:.4f}"
            )
        )
        print(f"\n Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    # Save test results
    results_path = os.path.join(output_dir, "test_results.json")
    with open(results_path, "w") as f:
        json.dump({
            "model": HUB_MODEL_ID,
            "base_model": BASE_MODEL,
            "max_length": MAX_LENGTH,
            "stage1_dataset": "coastalcph/lex_glue (ledgar)",
            "stage2_dataset": "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
            "test_results": {k: float(v) for k, v in test_metrics.items()
                             if isinstance(v, (int, float))},
            "timestamp": datetime.now().isoformat(),
        }, f, indent=2)
    return output_dir
# ───────────────────────────────────────────────────────────────
# MAIN
# ───────────────────────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser(description="ClauseGuard v4 Training")
    parser.add_argument("--stage", type=int, default=0,
                        help="Run specific stage (1 or 2). Default: both")
    parser.add_argument("--checkpoint", type=str, default="./stage1_ledgar_best",
                        help="Stage 1 checkpoint path for Stage 2")
    args = parser.parse_args()
    print("🛡️ ClauseGuard v4 Training")
    print(f" Model: {BASE_MODEL}")
    print(f" Max length: {MAX_LENGTH}")
    print(f" Hub: {HUB_MODEL_ID}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if args.stage in (0, 1):
        checkpoint = run_stage1(tokenizer)
    else:
        checkpoint = args.checkpoint
    if args.stage in (0, 2):
        run_stage2(tokenizer, checkpoint)
    print("\n✅ Training complete!")


if __name__ == "__main__":
    main()
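# Example (illustrative only, not executed by this script): loading the final
# checkpoint for inference with the standard transformers pipeline API. The
# local path and sample clause below are hypothetical.
#
#   from transformers import pipeline
#   clf = pipeline("text-classification", model="./clauseguard-deberta-final")
#   print(clf("This Agreement shall be governed by the laws of the State of New York."))
#   # -> e.g. [{'label': 'Governing Law', 'score': 0.99}]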