# ClauseGuard / ml / train_classifier_v4.py
# Author: gaurv007
# Commit: Add ClauseGuard v4 training script (DeBERTa-v3-large + LEDGAR + CUAD + ASL)
# Revision: 1dc0b52 (verified)
"""
ClauseGuard v4 — 2-Stage DeBERTa-v3-large Training Script
═══════════════════════════════════════════════════════════
Stage 1: Pre-fine-tune on LEDGAR (60K legal provisions, 100 classes)
Stage 2: Fine-tune on CUAD (41 classes) with Asymmetric Loss
Usage:
python train_classifier_v4.py # Full 2-stage pipeline
python train_classifier_v4.py --stage 1 # Stage 1 only
python train_classifier_v4.py --stage 2 --checkpoint ./stage1_ledgar_best # Stage 2 only
Requirements:
pip install transformers datasets scikit-learn accelerate torch
Hardware: A100 80GB recommended (~4-6 hours total)
"""
import os
import gc
import argparse
import json
from collections import Counter
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
from sklearn.model_selection import train_test_split
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
EarlyStoppingCallback,
)
# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════
# Run-time overrides are read from environment variables; the rest is fixed.
BASE_MODEL = os.getenv("BASE_MODEL", "microsoft/deberta-v3-large")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
HUB_MODEL_ID = os.getenv("HUB_MODEL_ID", "gaurv007/clauseguard-deberta-v3-large")
# Only the string "true" (case-insensitive) enables pushing; anything else disables it.
PUSH_TO_HUB = os.getenv("PUSH_TO_HUB", "true").lower() == "true"
SEED = 42

# Stage 2 label names; position i is the display name for class id i.
# NOTE(review): this order is assumed to match the dataset's ``class_id``
# encoding -- confirm against the dataset card.
CUAD_LABELS = [
    "Document Name",
    "Parties",
    "Agreement Date",
    "Effective Date",
    "Expiration Date",
    "Renewal Term",
    "Notice Period to Terminate Renewal",
    "Governing Law",
    "Most Favored Nation",
    "Non-Compete",
    "Exclusivity",
    "No-Solicit of Customers",
    "No-Solicit of Employees",
    "Non-Disparagement",
    "Termination for Convenience",
    "ROFR/ROFO/ROFN",
    "Change of Control",
    "Anti-Assignment",
    "Revenue/Profit Sharing",
    "Price Restriction",
    "Minimum Commitment",
    "Volume Restriction",
    "IP Ownership Assignment",
    "Joint IP Ownership",
    "License Grant",
    "Non-Transferable License",
    "Affiliate License-Licensor",
    "Affiliate License-Licensee",
    "Unlimited/All-You-Can-Eat License",
    "Irrevocable or Perpetual License",
    "Source Code Escrow",
    "Post-Termination Services",
    "Audit Rights",
    "Uncapped Liability",
    "Cap on Liability",
    "Liquidated Damages",
    "Warranty Duration",
    "Insurance",
    "Covenant Not to Sue",
    "Third Party Beneficiary",
    "Other",
]
NUM_CUAD_LABELS = len(CUAD_LABELS)
# ═══════════════════════════════════════════════════════════════
# ASYMMETRIC LOSS (arxiv:2009.14119)
# ═══════════════════════════════════════════════════════════════
class AsymmetricLoss(nn.Module):
    """Class-weighted focal cross-entropy for single-label, multi-class targets.

    Named after Asymmetric Loss (arXiv:2009.14119), but what ``forward``
    actually computes is, per sample,

        w[y] * (1 - p_y) ** gamma_neg * -log(p_y)

    averaged over the batch. Here ``p_y`` is the softmax probability of the
    target class and ``w`` are the optional class weights (``w[y] = 1``
    without weights). ``gamma_pos``, ``clip`` and ``eps`` are stored for
    interface compatibility but are NOT used by ``forward``.

    Args:
        gamma_pos: Stored only; unused by ``forward``.
        gamma_neg: Focusing exponent applied to ``(1 - p_y)``.
        clip: Stored only; unused by ``forward``.
        eps: Stored only; unused by ``forward``.
        class_weights: Optional per-class weights (sequence or tensor, one
            entry per class). Registered as a float32 buffer so it follows
            ``.to()`` / ``.cuda()`` and is saved in ``state_dict``.
    """

    def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,
                 class_weights=None):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.eps = eps
        if class_weights is not None:
            # as_tensor accepts lists as well as existing tensors (torch.tensor
            # on a tensor warns); clone/detach keeps the buffer independent.
            weights = torch.as_tensor(class_weights, dtype=torch.float32)
            self.register_buffer('class_weights', weights.clone().detach())
        else:
            self.class_weights = None

    def forward(self, logits, targets):
        """Return the batch-mean focal-weighted (optionally class-weighted) CE.

        Args:
            logits: ``(batch, num_classes)`` unnormalised scores.
            targets: ``(batch,)`` integer class indices.

        Returns:
            Scalar float32 loss tensor.
        """
        # Compute in float32: in fp16/bf16 the (1 - p) ** gamma factor loses
        # precision, and a float32 weight buffer would not match half logits.
        log_probs = F.log_softmax(logits.float(), dim=-1)
        log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        p_t = log_p_t.exp()
        # Same value as F.cross_entropy(..., weight=w, reduction='none').
        ce_loss = -log_p_t
        if self.class_weights is not None:
            weight = self.class_weights.to(device=ce_loss.device)
            ce_loss = ce_loss * weight[targets]
        focal_weight = (1 - p_t) ** self.gamma_neg
        return (focal_weight * ce_loss).mean()
# ═══════════════════════════════════════════════════════════════
# CUSTOM TRAINER
# ═══════════════════════════════════════════════════════════════
class ASLTrainer(Trainer):
    """``Trainer`` whose loss comes from a pluggable ``(logits, labels)`` callable.

    Pass ``asl_loss_fn`` to use a custom criterion; when it is ``None`` the
    loss falls back to ``F.cross_entropy``.
    """

    def __init__(self, *args, asl_loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.asl = asl_loss_fn

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch=None):
        """Compute the loss for one batch.

        ``labels`` is popped from ``inputs`` (mutating it) so the model is
        called without labels. ``num_items_in_batch`` is accepted for
        signature compatibility with ``Trainer`` and ignored.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        targets = inputs.pop("labels")
        model_out = model(**inputs)
        criterion = F.cross_entropy if self.asl is None else self.asl
        batch_loss = criterion(model_out.logits, targets)
        if return_outputs:
            return batch_loss, model_out
        return batch_loss
# ═══════════════════════════════════════════════════════════════
# METRICS
# ═══════════════════════════════════════════════════════════════
def compute_metrics(eval_pred):
    """Score argmax predictions with accuracy and micro/macro/weighted F1.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores with
            classes on the last axis) and ``label_ids`` (integer labels).

    Returns:
        Dict with keys ``accuracy``, ``micro_f1``, ``macro_f1`` and
        ``weighted_f1``; the F1 scores use ``zero_division=0``.
    """
    y_true = eval_pred.label_ids
    y_pred = np.argmax(eval_pred.predictions, axis=-1)
    scores = {"accuracy": (y_pred == y_true).mean()}
    for key, averaging in (("micro_f1", "micro"),
                           ("macro_f1", "macro"),
                           ("weighted_f1", "weighted")):
        scores[key] = f1_score(y_true, y_pred, average=averaging,
                               zero_division=0)
    return scores
# ═══════════════════════════════════════════════════════════════
# STAGE 1: LEDGAR
# ═══════════════════════════════════════════════════════════════
def run_stage1(tokenizer, output_dir="./stage1_ledgar_best"):
    """Stage 1: fine-tune BASE_MODEL on LexGLUE LEDGAR and save the best model.

    Trains a single-label classifier on the LEDGAR ``train`` split. Each epoch
    is evaluated on ``validation``; the best epoch by macro-F1 is kept, with
    early stopping at patience 2. The function then reports test micro/macro-F1
    and saves the model and tokenizer.

    Args:
        tokenizer: Tokenizer used to encode the ``text`` column; it is also
            saved next to the model.
        output_dir: Directory the best model and tokenizer are written to.

    Returns:
        ``output_dir``, suitable as the Stage 2 checkpoint path.
    """
    print("\n" + "=" * 60)
    print(" STAGE 1: Pre-fine-tune on LEDGAR (100 classes)")
    print("=" * 60)
    ledgar = load_dataset("coastalcph/lex_glue", "ledgar")
    num_labels = ledgar['train'].features['label'].num_classes
    print(f" Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,}")
    print(f" Classes: {num_labels}")

    def preprocess(examples):
        # No padding here: DataCollatorWithPadding pads each batch dynamically.
        tok = tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH,
                        padding=False)
        tok["labels"] = examples["label"]
        return tok

    tokenized = ledgar.map(preprocess, batched=True,
                           remove_columns=ledgar["train"].column_names)
    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=num_labels,
        problem_type="single_label_classification",
        ignore_mismatched_sizes=True,
    )
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Mixed precision: bf16 when compute capability >= 8, fp16 on older
    # CUDA GPUs, full precision when no GPU is present.
    has_cuda = torch.cuda.is_available()
    cc_major = torch.cuda.get_device_capability()[0] if has_cuda else 0
    args = TrainingArguments(
        output_dir="./stage1_ledgar",
        num_train_epochs=5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        weight_decay=0.06,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        bf16=has_cuda and cc_major >= 8,
        fp16=has_cuda and cc_major < 8,
        logging_strategy="steps",
        logging_steps=50,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        dataloader_num_workers=2,
        seed=SEED,
        gradient_checkpointing=True,
    )
    trainer = Trainer(
        model=model, args=args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        processing_class=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
    result = trainer.train()
    print(f"\n Stage 1 training loss: {result.training_loss:.4f}")
    test_metrics = trainer.evaluate(tokenized["test"])
    print(f" Stage 1 test micro-F1: {test_metrics['eval_micro_f1']:.4f}")
    print(f" Stage 1 test macro-F1: {test_metrics['eval_macro_f1']:.4f}")
    # With load_best_model_at_end=True this saves the best epoch's weights.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f" Saved to {output_dir}")

    # Free GPU memory before Stage 2 loads another large model. Collect
    # garbage first so tensors kept alive only by reference cycles are
    # released to the caching allocator, then return the cache to the driver.
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()
    return output_dir
# ═══════════════════════════════════════════════════════════════
# STAGE 2: CUAD
# ═══════════════════════════════════════════════════════════════
def _load_cuad_splits():
    """Load the CUAD clause-classification table and split it 80/10/10 by file.

    The split is done on unique ``file_name`` values, so all rows from one
    file end up in the same split.

    Returns:
        Dict mapping ``"train"``, ``"val"`` and ``"test"`` to ``Dataset`` objects.
    """
    cuad_raw = load_dataset(
        "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
        split="train"
    )
    cuad_df = cuad_raw.to_pandas()
    unique_files = cuad_df['file_name'].unique()
    train_files, holdout_files = train_test_split(unique_files, test_size=0.2,
                                                  random_state=SEED)
    val_files, test_files = train_test_split(holdout_files, test_size=0.5,
                                             random_state=SEED)
    file_groups = {"train": train_files, "val": val_files, "test": test_files}
    return {
        name: Dataset.from_pandas(
            cuad_df[cuad_df['file_name'].isin(files)].reset_index(drop=True)
        )
        for name, files in file_groups.items()
    }


def _cuad_class_weights(labels, num_labels, max_weight=10.0):
    """Return ``num_labels`` "balanced" class weights, capped at ``max_weight``.

    Weight for class i is ``total / (num_labels * count_i)``. A class that is
    absent from ``labels`` is treated as having a count of 1.
    """
    counts = Counter(labels)
    total = sum(counts.values())
    return [min(max_weight, total / (num_labels * counts.get(i, 1)))
            for i in range(num_labels)]


def run_stage2(tokenizer, checkpoint_path, output_dir="./clauseguard-deberta-final"):
    """Stage 2: fine-tune a (Stage 1) checkpoint on CUAD with AsymmetricLoss.

    The classification head is re-initialised for NUM_CUAD_LABELS classes.
    The model is trained with class-weighted AsymmetricLoss, and the best
    epoch is selected by validation macro-F1 (early stopping, patience 3).
    Test metrics and a per-class report are printed, and the model and
    tokenizer are saved. ``test_results.json`` is written; when PUSH_TO_HUB
    is set, the model and results are uploaded to the Hub.

    Args:
        tokenizer: Tokenizer used to encode the ``clause`` column.
        checkpoint_path: Model directory or Hub id to start from.
        output_dir: Directory for the final model, tokenizer and results.

    Returns:
        ``output_dir``.

    Raises:
        ValueError: if any CUAD ``class_id`` lies outside
            ``[0, NUM_CUAD_LABELS)``.
    """
    print("\n" + "=" * 60)
    print(f" STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL")
    print("=" * 60)
    splits = _load_cuad_splits()
    for name, ds in splits.items():
        print(f" {name}: {len(ds)} rows")

    def preprocess_cuad(examples):
        # No padding here: DataCollatorWithPadding pads each batch dynamically.
        tok = tokenizer(examples["clause"], truncation=True, max_length=MAX_LENGTH,
                        padding=False)
        tok["labels"] = examples["class_id"]
        return tok

    tok_splits = {
        name: ds.map(preprocess_cuad, batched=True, remove_columns=ds.column_names)
        for name, ds in splits.items()
    }

    # Fail fast with a readable message instead of an opaque indexing error
    # (or CUDA device-side assert) inside the loss if a class id does not
    # fit the NUM_CUAD_LABELS-way head.
    for name, ds in tok_splits.items():
        out_of_range = sorted({int(y) for y in ds["labels"]
                               if not 0 <= y < NUM_CUAD_LABELS})
        if out_of_range:
            raise ValueError(
                f"CUAD {name} split has class_id values outside "
                f"[0, {NUM_CUAD_LABELS}): {out_of_range[:10]}"
            )

    # Load model from Stage 1 checkpoint; the 100-way LEDGAR head does not
    # match, so ignore_mismatched_sizes re-initialises it.
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint_path,
        num_labels=NUM_CUAD_LABELS,
        ignore_mismatched_sizes=True,
        problem_type="single_label_classification",
    )
    # Update label mapping. Transformers expects int keys in id2label; they
    # serialize to the same JSON, but str keys break in-session lookups by id.
    model.config.id2label = dict(enumerate(CUAD_LABELS))
    model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    class_weights = _cuad_class_weights(tok_splits["train"]["labels"],
                                        NUM_CUAD_LABELS)
    asl = AsymmetricLoss(gamma_pos=0, gamma_neg=4, clip=0.05,
                         class_weights=class_weights)
    if torch.cuda.is_available():
        asl = asl.cuda()

    # Mixed precision: bf16 when compute capability >= 8, fp16 on older
    # CUDA GPUs, full precision when no GPU is present.
    has_cuda = torch.cuda.is_available()
    cc_major = torch.cuda.get_device_capability()[0] if has_cuda else 0
    args = TrainingArguments(
        output_dir="./stage2_cuad",
        num_train_epochs=20,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        learning_rate=1e-5,
        weight_decay=0.06,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        bf16=has_cuda and cc_major >= 8,
        fp16=has_cuda and cc_major < 8,
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",
        push_to_hub=PUSH_TO_HUB,
        hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
        dataloader_num_workers=2,
        seed=SEED,
        gradient_checkpointing=True,
    )
    trainer = ASLTrainer(
        model=model, args=args,
        asl_loss_fn=asl,
        train_dataset=tok_splits["train"],
        eval_dataset=tok_splits["val"],
        processing_class=tokenizer,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    result = trainer.train()
    print(f"\n Stage 2 training loss: {result.training_loss:.4f}")

    # Evaluate
    test_metrics = trainer.evaluate(tok_splits["test"])
    print(f"\n{'='*60}")
    print(f" CUAD TEST RESULTS")
    print(f"{'='*60}")
    print(f" Accuracy: {test_metrics['eval_accuracy']:.4f}")
    print(f" Micro-F1: {test_metrics['eval_micro_f1']:.4f}")
    print(f" Macro-F1: {test_metrics['eval_macro_f1']:.4f}")
    print(f" Weighted-F1: {test_metrics['eval_weighted_f1']:.4f}")

    # Full per-class report, restricted to classes that actually occur.
    preds_out = trainer.predict(tok_splits["test"])
    preds = np.argmax(preds_out.predictions, axis=-1)
    labels = preds_out.label_ids
    present = sorted(set(labels) | set(preds))
    names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f"Class-{i}" for i in present]
    print("\n" + classification_report(labels, preds, labels=present,
                                       target_names=names, zero_division=0, digits=4))

    # Save
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Write the results BEFORE pushing. Trainer.push_to_hub uploads
    # trainer.args.output_dir (not output_dir), so when pushing, also place
    # a copy there so the file is included in the upload.
    results = {
        "model": HUB_MODEL_ID,
        "base_model": BASE_MODEL,
        "max_length": MAX_LENGTH,
        "stage1_dataset": "coastalcph/lex_glue (ledgar)",
        "stage2_dataset": "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
        "test_results": {k: float(v) for k, v in test_metrics.items()
                         if isinstance(v, (int, float))},
        "timestamp": datetime.now().isoformat(),
    }
    result_dirs = [output_dir]
    if PUSH_TO_HUB:
        result_dirs.append(trainer.args.output_dir)
    for results_dir in result_dirs:
        os.makedirs(results_dir, exist_ok=True)
        results_path = os.path.join(results_dir, "test_results.json")
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

    if PUSH_TO_HUB:
        trainer.push_to_hub(
            commit_message=(
                f"ClauseGuard v4: DeBERTa-v3-large LEDGAR→CUAD + ASL | "
                f"micro-F1={test_metrics['eval_micro_f1']:.4f} "
                f"macro-F1={test_metrics['eval_macro_f1']:.4f}"
            )
        )
        print(f"\n Pushed to https://huggingface.co/{HUB_MODEL_ID}")
    return output_dir
# ═══════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════
def main():
    """Parse CLI arguments and run Stage 1, Stage 2, or both.

    ``--stage 0`` (the default) runs LEDGAR pre-fine-tuning, then CUAD
    fine-tuning on the Stage 1 output. ``--stage 1`` runs only Stage 1.
    ``--stage 2`` runs only Stage 2, starting from ``--checkpoint``.
    """
    parser = argparse.ArgumentParser(description="ClauseGuard v4 Training")
    # choices rejects values such as 3 that would otherwise silently run nothing.
    parser.add_argument("--stage", type=int, default=0, choices=(0, 1, 2),
                        help="Run specific stage (1 or 2). Default: both")
    parser.add_argument("--checkpoint", type=str, default="./stage1_ledgar_best",
                        help="Stage 1 checkpoint path for Stage 2")
    args = parser.parse_args()
    print("🛡️ ClauseGuard v4 Training")
    print(f" Model: {BASE_MODEL}")
    print(f" Max length: {MAX_LENGTH}")
    print(f" Hub: {HUB_MODEL_ID}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        # The device-properties field is ``total_memory`` (bytes); there is
        # no ``total_mem`` attribute.
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        print(f" VRAM: {total_bytes / 1e9:.1f} GB")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if args.stage in (0, 1):
        checkpoint = run_stage1(tokenizer)
    else:
        checkpoint = args.checkpoint
    if args.stage in (0, 2):
        run_stage2(tokenizer, checkpoint)
    print("\n✅ Training complete!")


if __name__ == "__main__":
    main()