Spaces:
Sleeping
Sleeping
"""
ClauseGuard — Fine-tune Legal-BERT on CLAUDETTE/LexGLUE unfair_tos
Multi-label classification (8 unfair clause categories)
Compatible with: Transformers 5.6.x, Datasets 4.8.x (April 2026)
"""
| import os | |
| import numpy as np | |
| import torch | |
| from datasets import load_dataset, Sequence, Value | |
| from sklearn.metrics import f1_score, precision_score, recall_score | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| DataCollatorWithPadding, | |
| Trainer, | |
| TrainingArguments, | |
| EarlyStoppingCallback, | |
| ) | |
# ─── Config ───
# Settings overridable via environment variables.
MODEL_NAME = os.environ.get("BASE_MODEL", "nlpaueb/legal-bert-base-uncased")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./clauseguard-model")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-legal-bert")
# Accept the common truthy spellings, ignoring case and surrounding whitespace;
# any other value disables pushing to the Hub.
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").strip().lower() in {"true", "1", "yes"}
MAX_LENGTH = 512  # maximum tokenized length; longer clauses are truncated

# Class names indexed by integer class id. NOTE(review): assumed to follow the
# order of the unfair_tos `labels` ClassLabel feature — verify against
# dataset["train"].features.
LABEL_NAMES = [
    "Limitation of liability",
    "Unilateral termination",
    "Unilateral change",
    "Content removal",
    "Contract by using",
    "Choice of law",
    "Jurisdiction",
    "Arbitration",
]
# Derived from LABEL_NAMES so the label count and the names cannot drift apart.
NUM_LABELS = len(LABEL_NAMES)

print("ClauseGuard Model Training")
print(f" Base model: {MODEL_NAME}")
print(f" Output: {OUTPUT_DIR}")
print(f" Push to Hub: {PUSH_TO_HUB} -> {HUB_MODEL_ID}")
# ─── 1. Load Dataset ───
# Fetch the UNFAIR-ToS configuration of LexGLUE; its train / validation /
# test splits are used later in the script.
print("Loading coastalcph/lex_glue (unfair_tos)...")
dataset = load_dataset("coastalcph/lex_glue", "unfair_tos")
train_n, val_n, test_n = (len(dataset[split_name]) for split_name in ("train", "validation", "test"))
print(f" Train: {train_n} | Val: {val_n} | Test: {test_n}")
# ─── 2. Load Model + Tokenizer ───
print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    # One independent sigmoid/BCE decision per label (a clause may be unfair in
    # several ways at once).
    problem_type="multi_label_classification",
    # Both maps use integer class ids, derived from LABEL_NAMES so they agree.
    id2label=dict(enumerate(LABEL_NAMES)),
    label2id={name: idx for idx, name in enumerate(LABEL_NAMES)},
)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    config=config,
    # Re-initialize the classification head instead of failing if the
    # checkpoint (e.g. a BASE_MODEL override) ships a head of another size.
    ignore_mismatched_sizes=True,
)
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
# ─── 3. Preprocess ───
def preprocess(examples):
    """Tokenize a batch of clauses and turn their label ids into multi-hot rows.

    Written for ``dataset.map(preprocess, batched=True)``, so ``examples`` maps
    each column name to a list with one entry per example.

    Args:
        examples: batch with a ``"text"`` column and a ``"labels"`` column,
            where each labels entry is an iterable of integer class ids.

    Returns:
        The tokenizer output (unpadded, truncated to ``MAX_LENGTH``) with a
        ``"labels"`` entry added: one list of ``NUM_LABELS`` floats per
        example, 1.0 at every listed class id and 0.0 elsewhere.
    """
    def _multi_hot(class_ids):
        # Index assignment (not a membership test) so class ids behave exactly
        # like list indices.
        row = [0.0] * NUM_LABELS
        for class_id in class_ids:
            row[class_id] = 1.0
        return row

    encoded = tokenizer(
        examples["text"],
        padding=False,  # padding is left to the batch data collator
        truncation=True,
        max_length=MAX_LENGTH,
    )
    encoded["labels"] = [_multi_hot(ids) for ids in examples["labels"]]
    return encoded
print("Tokenizing dataset...")
tokenized_ds = dataset.map(
    preprocess,
    batched=True,
    # Drop the raw columns; preprocess() returns the tokenizer outputs plus a
    # multi-hot "labels" column.
    remove_columns=dataset["train"].column_names,
)
# preprocess() emits Python floats, so Arrow infers float64 for "labels".
# Cast to float32 to match the model's float32 logits; otherwise the
# BCEWithLogitsLoss used for multi_label_classification can fail with a
# Float/Double dtype mismatch. DatasetDict.cast_column casts every split.
tokenized_ds = tokenized_ds.cast_column("labels", Sequence(Value("float32")))
tokenized_ds.set_format("torch")
# ─── 4. Metrics ───
def compute_metrics(eval_pred, threshold=0.5):
    """Compute multi-label classification metrics for the Trainer.

    A label counts as predicted when its sigmoid probability exceeds
    ``threshold``. Instead of evaluating ``1 / (1 + exp(-x))`` (which emits
    overflow warnings for large-magnitude logits), the threshold is mapped to
    logit space: ``sigmoid(x) > t  <=>  x > log(t / (1 - t))``.

    Args:
        eval_pred: Trainer ``EvalPrediction`` whose ``predictions`` are the
            logits (or a tuple whose first element is the logits) and whose
            ``label_ids`` are the multi-hot targets. Both are assumed to be
            shaped (n_examples, NUM_LABELS).
        threshold: probability cut-off, strictly between 0 and 1.

    Returns:
        dict with ``micro_f1``, ``macro_f1``, micro ``precision`` / ``recall``
        and one ``f1_<first 15 chars of label name>`` entry per class.

    Raises:
        ValueError: if ``threshold`` is not strictly between 0 and 1.
    """
    if not 0.0 < threshold < 1.0:
        raise ValueError(f"threshold must be in (0, 1), got {threshold!r}")

    logits = eval_pred.predictions
    # Some models return extra outputs alongside the logits.
    if isinstance(logits, tuple):
        logits = logits[0]

    logit_cut = np.log(threshold / (1.0 - threshold))  # exactly 0.0 for 0.5
    preds = (np.asarray(logits) > logit_cut).astype(int)
    labels = np.asarray(eval_pred.label_ids).astype(int)

    micro_f1 = f1_score(labels, preds, average="micro", zero_division=0)
    macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
    micro_p = precision_score(labels, preds, average="micro", zero_division=0)
    micro_r = recall_score(labels, preds, average="micro", zero_division=0)
    per_class = f1_score(labels, preds, average=None, zero_division=0)
    class_metrics = {
        f"f1_{name[:15]}": float(score)
        for name, score in zip(LABEL_NAMES, per_class)
    }
    return {
        "micro_f1": micro_f1,
        "macro_f1": macro_f1,
        "precision": micro_p,
        "recall": micro_r,
        **class_metrics,
    }
# ─── 5. Training ───
print("Starting training...")
training_args = TrainingArguments(
    # Output location and reproducibility
    output_dir=OUTPUT_DIR,
    seed=42,
    # Optimisation schedule; the epoch count is an upper bound because early
    # stopping (patience 3, configured below) may end training sooner.
    num_train_epochs=20,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    # Precision: fp16 only when a CUDA device is available
    fp16=torch.cuda.is_available(),
    bf16=False,
    # Evaluate and checkpoint every epoch; select the checkpoint with the best
    # validation macro-F1 and load it back at the end of training.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",
    greater_is_better=True,
    # Logging
    logging_strategy="steps",
    logging_steps=25,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="none",
    # Hub upload
    push_to_hub=PUSH_TO_HUB,
    hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
)
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    # Pads each batch dynamically to its longest sequence.
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
train_result = trainer.train()
print(f"Training complete! Loss: {train_result.training_loss:.4f}")
# ─── 6. Evaluate ───
print("Evaluating on test set...")
# Results are keyed with the Trainer's default "eval_" metric prefix.
test_results = trainer.evaluate(tokenized_ds["test"])
for display_name, metric_key in (
    ("micro-F1", "eval_micro_f1"),
    ("macro-F1", "eval_macro_f1"),
    ("precision", "eval_precision"),
    ("recall", "eval_recall"),
):
    # A missing metric prints as "nan" rather than a fabricated 0.0000 score.
    print(f" Test {display_name}: {test_results.get(metric_key, float('nan')):.4f}")
# ─── 7. Save ───
# os.path.join avoids a doubled separator when OUTPUT_DIR ends with "/".
final_dir = os.path.join(OUTPUT_DIR, "final")
trainer.save_model(final_dir)
# NOTE(review): Trainer.save_model likely also writes the processing_class
# (the tokenizer); the explicit save keeps final_dir self-contained regardless.
tokenizer.save_pretrained(final_dir)
print(f"Model saved to {final_dir}")
if PUSH_TO_HUB:
    print(f"Pushing to Hub: {HUB_MODEL_ID}")
    # NOTE(review): per the Transformers docs this uploads args.output_dir
    # (OUTPUT_DIR), not final_dir — confirm that is the intended artifact.
    trainer.push_to_hub(commit_message="ClauseGuard Legal-BERT fine-tuned on unfair_tos")
    print("Pushed successfully!")
print("Done!")