"""
ClauseGuard - Fine-tune Legal-BERT on CLAUDETTE/LexGLUE unfair_tos
Multi-label classification (8 unfair clause categories)
Compatible with: Transformers 5.6.x, Datasets 4.8.x (April 2026)
"""
import os
import numpy as np
import torch
from datasets import load_dataset, Sequence, Value
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
EarlyStoppingCallback,
)
# ─── Config ───
MODEL_NAME = os.environ.get("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./clauseguard-model")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-legal-bert")
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
NUM_LABELS = 8
MAX_LENGTH = 512
LABEL_NAMES = [
"Limitation of liability",
"Unilateral termination",
"Unilateral change",
"Content removal",
"Contract by using",
"Choice of law",
"Jurisdiction",
"Arbitration",
]
print(f"ClauseGuard Model Training")
print(f" Base model: {MODEL_NAME}")
print(f" Output: {OUTPUT_DIR}")
print(f" Push to Hub: {PUSH_TO_HUB} -> {HUB_MODEL_ID}")
# ─── 1. Load Dataset ───
print("Loading coastalcph/lex_glue (unfair_tos)...")
dataset = load_dataset("coastalcph/lex_glue", "unfair_tos")
print(f" Train: {len(dataset['train'])} | Val: {len(dataset['validation'])} | Test: {len(dataset['test'])}")
# ─── 2. Load Model + Tokenizer ───
print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(
MODEL_NAME,
num_labels=NUM_LABELS,
problem_type="multi_label_classification",
id2label={str(i): n for i, n in enumerate(LABEL_NAMES)},
label2id={n: i for i, n in enumerate(LABEL_NAMES)},
)
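# problem_type="multi_label_classification" makes the model compute BCEWithLogitsLoss
# internally, which expects float multi-hot targets rather than integer class indices
# (hence the float32 cast after tokenization below).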
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
config=config,
ignore_mismatched_sizes=True,
)
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
# ─── 3. Preprocess ───
def preprocess(examples):
tokenized = tokenizer(
examples["text"],
truncation=True,
max_length=MAX_LENGTH,
padding=False,
)
batch_labels = []
for lbls in examples["labels"]:
vec = [0.0] * NUM_LABELS
for l in lbls:
vec[l] = 1.0
batch_labels.append(vec)
tokenized["labels"] = batch_labels
return tokenized
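# preprocess() turns each example's list of label indices into a fixed-length multi-hot
# target, e.g. indices [0, 6] -> [1, 0, 0, 0, 0, 0, 1, 0]
# ("Limitation of liability" + "Jurisdiction"); an empty list stays all zeros.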
print("Tokenizing dataset...")
tokenized_ds = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
# Critical: cast labels to float32 for BCEWithLogitsLoss (otherwise Datasets stores the
# column with a non-float32 dtype, and the loss expects targets matching the float32 logits)
for split in tokenized_ds:
tokenized_ds[split] = tokenized_ds[split].cast_column("labels", Sequence(Value("float32")))
tokenized_ds.set_format("torch")
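# Optional sanity check (illustrative, not required for training): after the cast and
# set_format, every label vector should be a float32 tensor of length NUM_LABELS.
_first_labels = tokenized_ds["train"][0]["labels"]
assert _first_labels.dtype == torch.float32 and _first_labels.shape[-1] == NUM_LABELS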
# ─── 4. Metrics ───
def compute_metrics(eval_pred):
logits, labels = eval_pred.predictions, eval_pred.label_ids
probs = 1 / (1 + np.exp(-logits))
preds = (probs > 0.5).astype(int)
labels = labels.astype(int)
micro_f1 = f1_score(labels, preds, average="micro", zero_division=0)
macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
micro_p = precision_score(labels, preds, average="micro", zero_division=0)
micro_r = recall_score(labels, preds, average="micro", zero_division=0)
per_class = f1_score(labels, preds, average=None, zero_division=0)
class_metrics = {f"f1_{LABEL_NAMES[i][:15]}": float(per_class[i]) for i in range(NUM_LABELS)}
return {
"micro_f1": micro_f1,
"macro_f1": macro_f1,
"precision": micro_p,
"recall": micro_r,
**class_metrics,
}
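# compute_metrics receives raw logits of shape (num_examples, NUM_LABELS) plus the matching
# multi-hot label matrix; each label is predicted independently with a fixed 0.5 sigmoid
# threshold. Minimal standalone usage sketch (dummy data, kept commented out):
#   from collections import namedtuple
#   EvalPred = namedtuple("EvalPred", ["predictions", "label_ids"])
#   compute_metrics(EvalPred(np.zeros((2, NUM_LABELS)), np.zeros((2, NUM_LABELS))))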
# ─── 5. Training ───
print("Starting training...")
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=20,
per_device_train_batch_size=16,
per_device_eval_batch_size=32,
learning_rate=3e-5,
weight_decay=0.01,
warmup_ratio=0.1,
eval_strategy="epoch",
save_strategy="epoch",
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="macro_f1",
greater_is_better=True,
fp16=torch.cuda.is_available(),
bf16=False,
logging_strategy="steps",
logging_steps=25,
logging_first_step=True,
disable_tqdm=True,
report_to="none",
push_to_hub=PUSH_TO_HUB,
hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
seed=42,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_ds["train"],
eval_dataset=tokenized_ds["validation"],
processing_class=tokenizer,
data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
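# With load_best_model_at_end=True and metric_for_best_model="macro_f1", the checkpoint with
# the best validation macro-F1 is restored after training, and EarlyStoppingCallback stops
# training once that metric fails to improve for 3 consecutive evaluations.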
train_result = trainer.train()
print(f"Training complete! Loss: {train_result.training_loss:.4f}")
# ─── 6. Evaluate ───
print("Evaluating on test set...")
test_results = trainer.evaluate(tokenized_ds["test"])
print(f" Test micro-F1: {test_results.get('eval_micro_f1', 0):.4f}")
print(f" Test macro-F1: {test_results.get('eval_macro_f1', 0):.4f}")
print(f" Test precision: {test_results.get('eval_precision', 0):.4f}")
print(f" Test recall: {test_results.get('eval_recall', 0):.4f}")
# ─── 7. Save ───
final_dir = f"{OUTPUT_DIR}/final"
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"Model saved to {final_dir}")
if PUSH_TO_HUB:
print(f"Pushing to Hub: {HUB_MODEL_ID}")
trainer.push_to_hub(commit_message="ClauseGuard Legal-BERT fine-tuned on unfair_tos")
print("Pushed successfully!")
print("Done!")