gaurv007 committed (verified)
Commit 1dc0b52 · Parent(s): c6e0514

Add ClauseGuard v4 training script (DeBERTa-v3-large + LEDGAR + CUAD + ASL)

Files changed (1)
  1. ml/train_classifier_v4.py +434 -0
ml/train_classifier_v4.py ADDED
@@ -0,0 +1,434 @@
+ """
+ ClauseGuard v4 - 2-Stage DeBERTa-v3-large Training Script
+ ═══════════════════════════════════════════════════════════
+
+ Stage 1: Pre-fine-tune on LEDGAR (60K legal provisions, 100 classes)
+ Stage 2: Fine-tune on CUAD (41 classes) with Asymmetric Loss
+
+ Usage:
+     python train_classifier_v4.py                # Full 2-stage pipeline
+     python train_classifier_v4.py --stage 1      # Stage 1 only
+     python train_classifier_v4.py --stage 2 --checkpoint ./stage1_ledgar_best  # Stage 2 only
+
+ Requirements:
+     pip install transformers datasets scikit-learn accelerate torch
+
+ Hardware: A100 80GB recommended (~4-6 hours total)
+ """
+
+ import os
+ import gc
+ import argparse
+ import json
+ from collections import Counter
+ from datetime import datetime
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from datasets import load_dataset, Dataset
+ from sklearn.metrics import f1_score, classification_report
+ from sklearn.model_selection import train_test_split
+ from transformers import (
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+     DataCollatorWithPadding,
+     Trainer,
+     TrainingArguments,
+     EarlyStoppingCallback,
+ )
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # CONFIGURATION
+ # ═══════════════════════════════════════════════════════════════
+
+ BASE_MODEL = os.environ.get("BASE_MODEL", "microsoft/deberta-v3-large")
+ MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
+ HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "gaurv007/clauseguard-deberta-v3-large")
+ PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
+ SEED = 42
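+ # Each setting above can be overridden via the environment; for example
+ # (an illustrative invocation, not from the original script):
+ #   BASE_MODEL=microsoft/deberta-v3-base MAX_LENGTH=256 PUSH_TO_HUB=false \
+ #       python train_classifier_v4.py --stage 1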
+
+ CUAD_LABELS = [
+     "Document Name", "Parties", "Agreement Date", "Effective Date",
+     "Expiration Date", "Renewal Term", "Notice Period to Terminate Renewal",
+     "Governing Law", "Most Favored Nation", "Non-Compete", "Exclusivity",
+     "No-Solicit of Customers", "No-Solicit of Employees", "Non-Disparagement",
+     "Termination for Convenience", "ROFR/ROFO/ROFN", "Change of Control",
+     "Anti-Assignment", "Revenue/Profit Sharing", "Price Restriction",
+     "Minimum Commitment", "Volume Restriction", "IP Ownership Assignment",
+     "Joint IP Ownership", "License Grant", "Non-Transferable License",
+     "Affiliate License-Licensor", "Affiliate License-Licensee",
+     "Unlimited/All-You-Can-Eat License", "Irrevocable or Perpetual License",
+     "Source Code Escrow", "Post-Termination Services", "Audit Rights",
+     "Uncapped Liability", "Cap on Liability", "Liquidated Damages",
+     "Warranty Duration", "Insurance", "Covenant Not to Sue",
+     "Third Party Beneficiary", "Other",
+ ]
+ NUM_CUAD_LABELS = len(CUAD_LABELS)
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # ASYMMETRIC LOSS (arxiv:2009.14119)
+ # ═══════════════════════════════════════════════════════════════
+
+ class AsymmetricLoss(nn.Module):
+     """Focal-style loss with asymmetric gamma for class imbalance.
+
+     Multi-class adaptation: gamma_neg is used as the focal exponent on the
+     target-class probability. gamma_pos and clip come from the multi-label
+     formulation in the paper and are kept for API parity, but are unused in
+     this single-label variant.
+     """
+
+     def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, eps=1e-8,
+                  class_weights=None):
+         super().__init__()
+         self.gamma_pos = gamma_pos
+         self.gamma_neg = gamma_neg
+         self.clip = clip
+         self.eps = eps
+         if class_weights is not None:
+             # Buffer so the weights move with the module across devices.
+             self.register_buffer('class_weights',
+                                  torch.tensor(class_weights, dtype=torch.float32))
+         else:
+             self.class_weights = None
+
+     def forward(self, logits, targets):
+         """Multi-class focal cross-entropy with class weights."""
+         if self.class_weights is not None:
+             ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights,
+                                       reduction='none')
+         else:
+             ce_loss = F.cross_entropy(logits, targets, reduction='none')
+
+         # Down-weight examples the model already classifies confidently.
+         probs = F.softmax(logits, dim=-1)
+         p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
+         focal_weight = (1 - p_t) ** self.gamma_neg
+         loss = focal_weight * ce_loss
+         return loss.mean()
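+ # Worked example (illustrative, not part of the training flow): for logits
+ # [[2.0, 0.0]] and target 0, softmax gives p_t ≈ 0.881, so the focal weight
+ # is (1 - 0.881)**4 ≈ 2e-4 and an already-confident example contributes
+ # almost nothing, while a misclassified one keeps near-full weight.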
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # CUSTOM TRAINER
+ # ═══════════════════════════════════════════════════════════════
+
+ class ASLTrainer(Trainer):
+     """Trainer that swaps the default cross-entropy for AsymmetricLoss."""
+
+     def __init__(self, *args, asl_loss_fn=None, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.asl = asl_loss_fn
+
+     def compute_loss(self, model, inputs, return_outputs=False,
+                      num_items_in_batch=None):
+         labels = inputs.pop("labels")
+         outputs = model(**inputs)
+         logits = outputs.logits
+         if self.asl is not None:
+             loss = self.asl(logits, labels)
+         else:
+             loss = F.cross_entropy(logits, labels)
+         return (loss, outputs) if return_outputs else loss
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # METRICS
+ # ═══════════════════════════════════════════════════════════════
+
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred.predictions, eval_pred.label_ids
+     preds = np.argmax(logits, axis=-1)
+     return {
+         "accuracy": (preds == labels).mean(),
+         "micro_f1": f1_score(labels, preds, average="micro", zero_division=0),
+         "macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
+         "weighted_f1": f1_score(labels, preds, average="weighted", zero_division=0),
+     }
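+ # Illustrative contrast: with labels [0, 1, 2] and preds [0, 0, 2],
+ # micro-F1 = 2/3 ≈ 0.667 (pooled over samples), while macro-F1 averages
+ # per-class F1: (0.667 + 0.0 + 1.0) / 3 ≈ 0.556. Model selection below uses
+ # macro_f1 so rare clause types count as much as common ones.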
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # STAGE 1: LEDGAR
+ # ═══════════════════════════════════════════════════════════════
+
+ def run_stage1(tokenizer, output_dir="./stage1_ledgar_best"):
+     print("\n" + "=" * 60)
+     print(" STAGE 1: Pre-fine-tune on LEDGAR (100 classes)")
+     print("=" * 60)
+
+     ledgar = load_dataset("coastalcph/lex_glue", "ledgar")
+     num_labels = ledgar['train'].features['label'].num_classes
+     print(f" Train: {len(ledgar['train']):,} | Val: {len(ledgar['validation']):,}")
+     print(f" Classes: {num_labels}")
+
+     def preprocess(examples):
+         tok = tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH,
+                         padding=False)
+         tok["labels"] = examples["label"]
+         return tok
+
+     tokenized = ledgar.map(preprocess, batched=True,
+                            remove_columns=ledgar["train"].column_names)
+
+     model = AutoModelForSequenceClassification.from_pretrained(
+         BASE_MODEL, num_labels=num_labels,
+         problem_type="single_label_classification",
+         ignore_mismatched_sizes=True,
+     )
+     print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     args = TrainingArguments(
+         output_dir="./stage1_ledgar",
+         num_train_epochs=5,
+         per_device_train_batch_size=8,
+         per_device_eval_batch_size=16,
+         gradient_accumulation_steps=4,
+         learning_rate=2e-5,
+         weight_decay=0.06,
+         warmup_ratio=0.1,
+         lr_scheduler_type="cosine",
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         save_total_limit=2,
+         load_best_model_at_end=True,
+         metric_for_best_model="macro_f1",
+         greater_is_better=True,
+         bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
+         fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
+         logging_strategy="steps",
+         logging_steps=50,
+         logging_first_step=True,
+         disable_tqdm=True,
+         report_to="none",
+         dataloader_num_workers=2,
+         seed=SEED,
+         gradient_checkpointing=True,
+     )
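+     # Note (added for clarity): batch size 8 with 4 accumulation steps gives
+     # an effective batch of 32 per device; bf16 is selected on Ampere or
+     # newer GPUs (compute capability >= 8), fp16 on older ones.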
+
+     trainer = Trainer(
+         model=model, args=args,
+         train_dataset=tokenized["train"],
+         eval_dataset=tokenized["validation"],
+         processing_class=tokenizer,
+         data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
+         compute_metrics=compute_metrics,
+         callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
+     )
+
+     result = trainer.train()
+     print(f"\n Stage 1 training loss: {result.training_loss:.4f}")
+
+     test_metrics = trainer.evaluate(tokenized["test"])
+     print(f" Stage 1 test micro-F1: {test_metrics['eval_micro_f1']:.4f}")
+     print(f" Stage 1 test macro-F1: {test_metrics['eval_macro_f1']:.4f}")
+
+     trainer.save_model(output_dir)
+     tokenizer.save_pretrained(output_dir)
+     print(f" Saved to {output_dir}")
+
+     # Free GPU memory before Stage 2 loads the checkpoint.
+     del model, trainer
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     return output_dir
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # STAGE 2: CUAD
+ # ═══════════════════════════════════════════════════════════════
+
+ def run_stage2(tokenizer, checkpoint_path, output_dir="./clauseguard-deberta-final"):
+     print("\n" + "=" * 60)
+     print(f" STAGE 2: Fine-tune on CUAD ({NUM_CUAD_LABELS} classes) with ASL")
+     print("=" * 60)
+
+     # Load CUAD and split by contract file so clauses from the same contract
+     # never straddle the train/val/test boundary.
+     cuad_raw = load_dataset(
+         "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
+         split="train"
+     )
+     cuad_df = cuad_raw.to_pandas()
+
+     unique_files = cuad_df['file_name'].unique()
+     train_files, test_files = train_test_split(unique_files, test_size=0.2,
+                                                random_state=SEED)
+     val_files, test_files = train_test_split(test_files, test_size=0.5,
+                                              random_state=SEED)
+
+     splits = {
+         "train": Dataset.from_pandas(
+             cuad_df[cuad_df['file_name'].isin(train_files)].reset_index(drop=True)
+         ),
+         "val": Dataset.from_pandas(
+             cuad_df[cuad_df['file_name'].isin(val_files)].reset_index(drop=True)
+         ),
+         "test": Dataset.from_pandas(
+             cuad_df[cuad_df['file_name'].isin(test_files)].reset_index(drop=True)
+         ),
+     }
+
+     for name, ds in splits.items():
+         print(f" {name}: {len(ds)} rows")
+
+     def preprocess_cuad(examples):
+         tok = tokenizer(examples["clause"], truncation=True, max_length=MAX_LENGTH,
+                         padding=False)
+         tok["labels"] = examples["class_id"]
+         return tok
+
+     tok_splits = {}
+     for name, ds in splits.items():
+         tok_splits[name] = ds.map(preprocess_cuad, batched=True,
+                                   remove_columns=ds.column_names)
+
+     # Load model from the Stage 1 checkpoint; the classification head is
+     # re-initialized for the new label count (hence ignore_mismatched_sizes).
+     model = AutoModelForSequenceClassification.from_pretrained(
+         checkpoint_path,
+         num_labels=NUM_CUAD_LABELS,
+         ignore_mismatched_sizes=True,
+         problem_type="single_label_classification",
+     )
+
+     # Update label mapping (int keys, per the transformers convention).
+     model.config.id2label = {i: name for i, name in enumerate(CUAD_LABELS)}
+     model.config.label2id = {name: i for i, name in enumerate(CUAD_LABELS)}
+
+     print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     # Compute inverse-frequency class weights, capped at 10x.
+     train_counts = Counter(tok_splits["train"]["labels"])
+     total = sum(train_counts.values())
+     class_weights = []
+     for i in range(NUM_CUAD_LABELS):
+         count = train_counts.get(i, 1)
+         weight = min(10.0, total / (NUM_CUAD_LABELS * count))
+         class_weights.append(weight)
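+     # Worked example (hypothetical counts): with 10,000 training rows and a
+     # class seen 50 times, the weight is 10000 / (41 * 50) ≈ 4.88; a class
+     # seen only 10 times would get 10000 / 410 ≈ 24.4, capped to 10.0.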
+
+     asl = AsymmetricLoss(gamma_pos=0, gamma_neg=4, clip=0.05,
+                          class_weights=class_weights)
+     if torch.cuda.is_available():
+         asl = asl.cuda()
+
+     args = TrainingArguments(
+         output_dir="./stage2_cuad",
+         num_train_epochs=20,
+         per_device_train_batch_size=8,
+         per_device_eval_batch_size=16,
+         gradient_accumulation_steps=4,
+         learning_rate=1e-5,
+         weight_decay=0.06,
+         warmup_ratio=0.1,
+         lr_scheduler_type="cosine",
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         save_total_limit=3,
+         load_best_model_at_end=True,
+         metric_for_best_model="macro_f1",
+         greater_is_better=True,
+         bf16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
+         fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
+         logging_strategy="steps",
+         logging_steps=25,
+         logging_first_step=True,
+         disable_tqdm=True,
+         report_to="none",
+         push_to_hub=PUSH_TO_HUB,
+         hub_model_id=HUB_MODEL_ID if PUSH_TO_HUB else None,
+         dataloader_num_workers=2,
+         seed=SEED,
+         gradient_checkpointing=True,
+     )
+
+     trainer = ASLTrainer(
+         model=model, args=args,
+         asl_loss_fn=asl,
+         train_dataset=tok_splits["train"],
+         eval_dataset=tok_splits["val"],
+         processing_class=tokenizer,
+         data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
+         compute_metrics=compute_metrics,
+         callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
+     )
+
+     result = trainer.train()
+     print(f"\n Stage 2 training loss: {result.training_loss:.4f}")
+
+     # Evaluate
+     test_metrics = trainer.evaluate(tok_splits["test"])
+     print(f"\n{'='*60}")
+     print(" CUAD TEST RESULTS")
+     print(f"{'='*60}")
+     print(f" Accuracy: {test_metrics['eval_accuracy']:.4f}")
+     print(f" Micro-F1: {test_metrics['eval_micro_f1']:.4f}")
+     print(f" Macro-F1: {test_metrics['eval_macro_f1']:.4f}")
+     print(f" Weighted-F1: {test_metrics['eval_weighted_f1']:.4f}")
+
+     # Full per-class report, restricted to classes present in the test split
+     preds_out = trainer.predict(tok_splits["test"])
+     preds = np.argmax(preds_out.predictions, axis=-1)
+     labels = preds_out.label_ids
+     present = sorted(set(labels) | set(preds))
+     names = [CUAD_LABELS[i] if i < len(CUAD_LABELS) else f"Class-{i}" for i in present]
+     print("\n" + classification_report(labels, preds, labels=present,
+                                        target_names=names, zero_division=0, digits=4))
+
+     # Save
+     trainer.save_model(output_dir)
+     tokenizer.save_pretrained(output_dir)
+
+     if PUSH_TO_HUB:
+         trainer.push_to_hub(
+             commit_message=(
+                 f"ClauseGuard v4: DeBERTa-v3-large LEDGAR→CUAD + ASL | "
+                 f"micro-F1={test_metrics['eval_micro_f1']:.4f} "
+                 f"macro-F1={test_metrics['eval_macro_f1']:.4f}"
+             )
+         )
+         print(f"\n Pushed to https://huggingface.co/{HUB_MODEL_ID}")
+
+     # Save test results
+     results_path = os.path.join(output_dir, "test_results.json")
+     with open(results_path, "w") as f:
+         json.dump({
+             "model": HUB_MODEL_ID,
+             "base_model": BASE_MODEL,
+             "max_length": MAX_LENGTH,
+             "stage1_dataset": "coastalcph/lex_glue (ledgar)",
+             "stage2_dataset": "dvgodoy/CUAD_v1_Contract_Understanding_clause_classification",
+             "test_results": {k: float(v) for k, v in test_metrics.items()
+                              if isinstance(v, (int, float))},
+             "timestamp": datetime.now().isoformat(),
+         }, f, indent=2)
+
+     return output_dir
+
+
+ # ═══════════════════════════════════════════════════════════════
+ # MAIN
+ # ═══════════════════════════════════════════════════════════════
+
+ def main():
+     parser = argparse.ArgumentParser(description="ClauseGuard v4 Training")
+     parser.add_argument("--stage", type=int, default=0, choices=(0, 1, 2),
+                         help="Run a specific stage (1 or 2). Default 0: both")
+     parser.add_argument("--checkpoint", type=str, default="./stage1_ledgar_best",
+                         help="Stage 1 checkpoint path for Stage 2")
+     args = parser.parse_args()
+
+     print("🛡️ ClauseGuard v4 Training")
+     print(f" Model: {BASE_MODEL}")
+     print(f" Max length: {MAX_LENGTH}")
+     print(f" Hub: {HUB_MODEL_ID}")
+     if torch.cuda.is_available():
+         print(f" GPU: {torch.cuda.get_device_name(0)}")
+         print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+
+     if args.stage in (0, 1):
+         checkpoint = run_stage1(tokenizer)
+     else:
+         checkpoint = args.checkpoint
+
+     if args.stage in (0, 2):
+         run_stage2(tokenizer, checkpoint)
+
+     print("\n✅ Training complete!")
+
+
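+ # For downstream use, the saved model can be loaded along these lines
+ # (an illustrative sketch, not part of this script):
+ #   from transformers import pipeline
+ #   clf = pipeline("text-classification",
+ #                  model="./clauseguard-deberta-final", truncation=True)
+ #   clf("This Agreement shall be governed by the laws of the State of Delaware.")
+ #   # -> [{'label': 'Governing Law', 'score': ...}] (expected, not guaranteed)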
+ if __name__ == "__main__":
+     main()