"""
ClauseGuard — Fine-tune Legal-BERT on CLAUDETTE/LexGLUE unfair_tos
Multi-label classification (8 unfair clause categories)
Compatible with: Transformers 5.6.x, Datasets 4.8.x (April 2026)
"""

import os

import numpy as np
import torch
from datasets import load_dataset, Sequence, Value
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)

# ─── Config ───
MODEL_NAME = os.environ.get("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./clauseguard-model")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-legal-bert")
# Accept the common truthy spellings ("1", "yes", ...) rather than only the literal "true".
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").strip().lower() in {"1", "true", "yes", "on"}
MAX_LENGTH = 512  # tokenizer truncation length (tokens)
# Position i in this list is the class with integer id i in the dataset's label lists.
LABEL_NAMES = [
    "Limitation of liability",
    "Unilateral termination",
    "Unilateral change",
    "Content removal",
    "Contract by using",
    "Choice of law",
    "Jurisdiction",
    "Arbitration",
]
# Derived rather than hard-coded so it can never disagree with LABEL_NAMES.
NUM_LABELS = len(LABEL_NAMES)

print("ClauseGuard Model Training")
print(f" Base model: {MODEL_NAME}")
print(f" Output: {OUTPUT_DIR}")
print(f" Push to Hub: {PUSH_TO_HUB} -> {HUB_MODEL_ID}")

# ─── 1. Load Dataset ───
print("Loading coastalcph/lex_glue (unfair_tos)...")
dataset = load_dataset("coastalcph/lex_glue", "unfair_tos")
print(f" Train: {len(dataset['train'])} | Val: {len(dataset['validation'])} | Test: {len(dataset['test'])}")

# preprocess() uses the dataset's integer label ids as indices into LABEL_NAMES, so the
# dataset's class order must match. If the schema exposes ClassLabel names, verify them;
# otherwise (plain integer ids) there is nothing to compare against.
_label_feature = getattr(dataset["train"].features.get("labels"), "feature", None)
_dataset_label_names = getattr(_label_feature, "names", None)
if _dataset_label_names is not None and list(_dataset_label_names) != LABEL_NAMES:
    raise ValueError(
        "LABEL_NAMES does not match the dataset's label order: "
        f"{list(_dataset_label_names)}"
    )
# ─── 2. Load Model + Tokenizer ───
print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    # One independent sigmoid/BCE output per label instead of a softmax over labels.
    problem_type="multi_label_classification",
    # Integer keys, matching the integer values used in label2id.
    id2label=dict(enumerate(LABEL_NAMES)),
    label2id={name: idx for idx, name in enumerate(LABEL_NAMES)},
)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    config=config,
    # Lets BASE_MODEL point at a checkpoint whose existing classification head has a
    # different number of labels; mismatched weights are re-initialised instead of raising.
    ignore_mismatched_sizes=True,
)
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")


# ─── 3. Preprocess ───
def _multi_hot(label_ids):
    """Return a NUM_LABELS-long list of floats with 1.0 at every index in ``label_ids``.

    An empty ``label_ids`` (a clause with no unfair category) yields all zeros.
    """
    vec = [0.0] * NUM_LABELS
    for label_id in label_ids:
        vec[label_id] = 1.0
    return vec


def preprocess(examples):
    """Tokenize a batch of clauses and convert their label-id lists to multi-hot vectors.

    Args:
        examples: batch dict from ``Dataset.map(batched=True)`` with a ``"text"`` column
            (passed to the tokenizer) and a ``"labels"`` column whose entries are lists of
            integer class ids indexing LABEL_NAMES.

    Returns:
        The tokenizer's output (unpadded; padding is done per batch by the collator)
        with an added ``"labels"`` list of multi-hot float vectors.
    """
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
        padding=False,
    )
    tokenized["labels"] = [_multi_hot(label_ids) for label_ids in examples["labels"]]
    return tokenized


print("Tokenizing dataset...")
tokenized_ds = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)
# The multi-hot vectors are Python floats, which `datasets` infers as float64. Cast to
# float32 so the BCEWithLogitsLoss targets match the dtype of the float32 logits.
tokenized_ds = tokenized_ds.cast_column("labels", Sequence(Value("float32")))
tokenized_ds.set_format("torch")
# ─── 4. Metrics ───
def compute_metrics(eval_pred):
    """Compute multi-label classification metrics for the Trainer.

    Args:
        eval_pred: ``EvalPrediction`` whose ``predictions`` are raw logits (one column
            per label, NUM_LABELS columns) and whose ``label_ids`` are the multi-hot targets.

    Returns:
        dict with ``micro_f1``, ``macro_f1``, micro-averaged ``precision`` and ``recall``,
        plus one ``f1_<first 15 chars of label name>`` entry per class. The Trainer
        prefixes these keys (e.g. ``eval_macro_f1``); ``macro_f1`` drives checkpoint
        selection and early stopping below.
    """
    logits = eval_pred.predictions
    # If the model returns extra outputs, the Trainer passes a tuple with the logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    # sigmoid(x) > 0.5 exactly when x > 0, so threshold the logits directly. This avoids
    # np.exp overflow warnings (and rounding at the boundary) when fp16 evaluation
    # produces float16 logits.
    preds = (np.asarray(logits) > 0).astype(int)
    labels = np.asarray(eval_pred.label_ids).astype(int)

    per_class = f1_score(labels, preds, average=None, zero_division=0)
    class_metrics = {
        f"f1_{name[:15]}": float(score) for name, score in zip(LABEL_NAMES, per_class)
    }
    return {
        "micro_f1": float(f1_score(labels, preds, average="micro", zero_division=0)),
        "macro_f1": float(f1_score(labels, preds, average="macro", zero_division=0)),
        "precision": float(precision_score(labels, preds, average="micro", zero_division=0)),
        "recall": float(recall_score(labels, preds, average="micro", zero_division=0)),
        **class_metrics,
    }


# ─── 5. Training ───
print("Starting training...")
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=20,  # upper bound; EarlyStoppingCallback usually stops sooner
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Evaluation and saving must share a cadence for load_best_model_at_end.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",  # key returned by compute_metrics
    greater_is_better=True,
    fp16=torch.cuda.is_available(),
    bf16=False,
    logging_strategy="steps",
    logging_steps=25,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="none",
    push_to_hub=PUSH_TO_HUB,
    hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
    seed=42,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
    processing_class=tokenizer,
    # Pads each batch to its longest sequence (preprocess leaves inputs unpadded).
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
    # Stop after 3 consecutive evaluations without macro_f1 improvement.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

train_result = trainer.train()
print(f"Training complete! Loss: {train_result.training_loss:.4f}")
# ─── 6. Evaluate ───
print("Evaluating on test set...")
# A "test" prefix keeps these numbers distinct from the validation-set "eval_*" metrics
# in the trainer's logs and saved state.
test_results = trainer.evaluate(eval_dataset=tokenized_ds["test"], metric_key_prefix="test")
# Index directly: every key below is produced by compute_metrics, and a missing key
# should fail loudly rather than print a misleading 0.0000.
print(f" Test micro-F1: {test_results['test_micro_f1']:.4f}")
print(f" Test macro-F1: {test_results['test_macro_f1']:.4f}")
print(f" Test precision: {test_results['test_precision']:.4f}")
print(f" Test recall: {test_results['test_recall']:.4f}")
# Writes test_results.json into OUTPUT_DIR so the held-out scores are kept with the run.
trainer.save_metrics("test", test_results)

# ─── 7. Save ───
final_dir = os.path.join(OUTPUT_DIR, "final")
# load_best_model_at_end=True, so the model saved here is the best-validation checkpoint.
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"Model saved to {final_dir}")

if PUSH_TO_HUB:
    print(f"Pushing to Hub: {HUB_MODEL_ID}")
    # NOTE(review): Trainer.push_to_hub uploads the trainer's output_dir (OUTPUT_DIR),
    # not final_dir — confirm this is the intended upload folder.
    trainer.push_to_hub(commit_message="ClauseGuard Legal-BERT fine-tuned on unfair_tos")
    print("Pushed successfully!")

print("Done!")