Cludoy commited on
Commit
330841b
Β·
verified Β·
1 Parent(s): bb1c80c

Add train.py

Browse files
Files changed (1) hide show
  1. train.py +298 -0
train.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training pipeline for TinyBert-CNN Intent Classifier.
3
+ Features: discriminative fine-tuning, warmup+cosine LR, early stopping,
4
+ comprehensive per-class/epoch metric tracking.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import DataLoader
10
+ import pandas as pd
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ import time
14
+ import json
15
+ import math
16
+ from sklearn.metrics import (
17
+ classification_report, confusion_matrix,
18
+ accuracy_score, precision_recall_fscore_support
19
+ )
20
+ import warnings
21
+ warnings.filterwarnings('ignore')
22
+
23
+ from TinyBert import IntentClassifier, IntentDataset
24
+
25
+ INTENT_NAMES = ['On-Topic Question', 'Off-Topic Question', 'Emotional-State', 'Pace-Related', 'Repeat/clarification']
26
+
27
+
28
+ # ─────────────────────────────────────────────────────────────────────
29
+ # UTILITIES
30
+ # ─────────────────────────────────────────────────────────────────────
31
+
32
+ class EarlyStopping:
33
+ def __init__(self, patience=3, min_delta=0.001, verbose=True):
34
+ self.patience = patience
35
+ self.min_delta = min_delta
36
+ self.verbose = verbose
37
+ self.counter = 0
38
+ self.best_loss = None
39
+ self.early_stop = False
40
+ self.best_epoch = 0
41
+
42
+ def __call__(self, val_loss, epoch):
43
+ if self.best_loss is None:
44
+ self.best_loss = val_loss
45
+ self.best_epoch = epoch
46
+ elif val_loss > self.best_loss - self.min_delta:
47
+ self.counter += 1
48
+ if self.verbose:
49
+ print(f" Early stopping counter: {self.counter}/{self.patience}")
50
+ if self.counter >= self.patience:
51
+ self.early_stop = True
52
+ if self.verbose:
53
+ print(f" [!] Early stopping triggered! Best epoch was {self.best_epoch}")
54
+ else:
55
+ self.best_loss = val_loss
56
+ self.best_epoch = epoch
57
+ self.counter = 0
58
+
59
+
60
+ class WarmupCosineScheduler:
61
+ def __init__(self, optimizer, warmup_steps, total_steps):
62
+ self.optimizer = optimizer
63
+ self.warmup_steps = warmup_steps
64
+ self.total_steps = total_steps
65
+ self.base_lrs = [pg['lr'] for pg in optimizer.param_groups]
66
+ self.current_step = 0
67
+
68
+ def step(self):
69
+ self.current_step += 1
70
+ if self.current_step <= self.warmup_steps:
71
+ scale = self.current_step / max(1, self.warmup_steps)
72
+ else:
73
+ progress = (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
74
+ scale = 0.5 * (1.0 + math.cos(math.pi * progress))
75
+ for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
76
+ pg['lr'] = base_lr * scale
77
+
78
+
79
+ def load_data(train_path, val_path, test_path):
80
+ train_df = pd.read_csv(train_path)
81
+ val_df = pd.read_csv(val_path)
82
+ test_df = pd.read_csv(test_path)
83
+ return train_df, val_df, test_df
84
+
85
+
86
+ def compute_class_weights(labels, num_classes, device):
87
+ counts = np.bincount(labels, minlength=num_classes).astype(float)
88
+ counts[counts == 0] = 1.0
89
+ weights = 1.0 / counts
90
+ weights = weights / weights.sum() * num_classes
91
+ return torch.tensor(weights, dtype=torch.float32).to(device)
92
+
93
+
94
+ def evaluate_model_full(classifier, loader):
95
+ """Full evaluation returning all metrics."""
96
+ classifier.model.eval()
97
+ all_preds, all_labels = [], []
98
+ total_loss = 0
99
+ criterion = nn.CrossEntropyLoss()
100
+
101
+ with torch.no_grad():
102
+ for batch in loader:
103
+ input_ids = batch['input_ids'].to(classifier.device)
104
+ attention_mask = batch['attention_mask'].to(classifier.device)
105
+ labels = batch['labels'].to(classifier.device)
106
+ token_type_ids = batch.get('token_type_ids')
107
+ if token_type_ids is not None:
108
+ token_type_ids = token_type_ids.to(classifier.device)
109
+
110
+ logits = classifier.model(input_ids, attention_mask, token_type_ids=token_type_ids)
111
+ loss = criterion(logits, labels)
112
+ total_loss += loss.item() * labels.size(0)
113
+
114
+ preds = torch.argmax(logits, dim=1).cpu().numpy()
115
+ all_preds.extend(preds)
116
+ all_labels.extend(labels.cpu().numpy())
117
+
118
+ n = len(all_labels)
119
+ avg_loss = total_loss / n
120
+ accuracy = accuracy_score(all_labels, all_preds)
121
+ precision, recall, f1, _ = precision_recall_fscore_support(
122
+ all_labels, all_preds, average='weighted', zero_division=0
123
+ )
124
+
125
+ return avg_loss, accuracy, precision, recall, f1, all_preds, all_labels
126
+
127
+
128
+ # ─────────────────────────────────────────────────────────────────────
129
+ # MAIN
130
+ # ─────────────────────────────────────────────────────────────────────
131
+
132
+ def main():
133
+ # ── Hyperparameters ─────────────────────────────────────────────
134
+ TRAIN_PATH = 'data/train.csv'
135
+ VAL_PATH = 'data/val.csv'
136
+ TEST_PATH = 'data/test.csv'
137
+ BATCH_SIZE = 16
138
+ EPOCHS = 20
139
+ BERT_LR = 2e-5 # Lower LR for BERT backbone
140
+ HEAD_LR = 1e-3 # Higher LR for CNN + FC head
141
+ WEIGHT_DECAY = 0.01
142
+ MAX_LENGTH = 128
143
+ PATIENCE = 5
144
+
145
+ hyperparams = {
146
+ 'batch_size': BATCH_SIZE,
147
+ 'epochs': EPOCHS,
148
+ 'bert_lr': BERT_LR,
149
+ 'head_lr': HEAD_LR,
150
+ 'weight_decay': WEIGHT_DECAY,
151
+ 'max_length': MAX_LENGTH,
152
+ 'patience': PATIENCE,
153
+ 'label_smoothing': 0.1
154
+ }
155
+
156
+ print("=" * 60)
157
+ print("TinyBert-CNN Multi-Input Model Training")
158
+ print("=" * 60)
159
+
160
+ start_time = time.time()
161
+
162
+ # ── Data ────────────────────────────────────────────────────────
163
+ train_df, val_df, test_df = load_data(TRAIN_PATH, VAL_PATH, TEST_PATH)
164
+ num_classes = train_df['label'].nunique()
165
+ print(f"Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)} | Classes: {num_classes}")
166
+
167
+ classifier = IntentClassifier(num_classes=num_classes)
168
+
169
+ train_dataset = IntentDataset(train_df.to_dict('records'), classifier.tokenizer, max_length=MAX_LENGTH)
170
+ val_dataset = IntentDataset(val_df.to_dict('records'), classifier.tokenizer, max_length=MAX_LENGTH)
171
+ test_dataset = IntentDataset(test_df.to_dict('records'), classifier.tokenizer, max_length=MAX_LENGTH)
172
+
173
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
174
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
175
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
176
+
177
+ # ── Optimizer with discriminative fine-tuning ───────────────────
178
+ class_weights = compute_class_weights(train_df['label'].values, num_classes, classifier.device)
179
+ criterion = nn.CrossEntropyLoss(label_smoothing=0.1, weight=class_weights)
180
+
181
+ bert_params = list(classifier.model.bert.parameters())
182
+ head_params = [p for n, p in classifier.model.named_parameters() if not n.startswith('bert.')]
183
+
184
+ optimizer = torch.optim.AdamW([
185
+ {'params': bert_params, 'lr': BERT_LR},
186
+ {'params': head_params, 'lr': HEAD_LR}
187
+ ], weight_decay=WEIGHT_DECAY)
188
+
189
+ total_steps = len(train_loader) * EPOCHS
190
+ warmup_steps = int(total_steps * 0.1)
191
+ scheduler = WarmupCosineScheduler(optimizer, warmup_steps, total_steps)
192
+ early_stopping = EarlyStopping(patience=PATIENCE)
193
+
194
+ best_val_f1 = 0.0
195
+ best_model_path = "best_tinybert.pt"
196
+
197
+ # ── Training history ────────────────────────────────────────────
198
+ history = {
199
+ 'train_loss': [],
200
+ 'val_loss': [],
201
+ 'val_acc': [],
202
+ 'val_f1': []
203
+ }
204
+
205
+ # ── Training loop ──────────────────────────────────────────────
206
+ for epoch in range(EPOCHS):
207
+ classifier.model.train()
208
+ train_loss = 0
209
+ train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
210
+
211
+ for batch in train_pbar:
212
+ loss = classifier.train_step(batch, optimizer, criterion)
213
+ torch.nn.utils.clip_grad_norm_(classifier.model.parameters(), max_norm=1.0)
214
+ scheduler.step()
215
+ train_loss += loss
216
+ train_pbar.set_postfix({'loss': f'{loss:.4f}'})
217
+
218
+ avg_train_loss = train_loss / len(train_loader)
219
+ val_loss, val_acc, val_prec, val_rec, val_f1, _, _ = evaluate_model_full(classifier, val_loader)
220
+
221
+ history['train_loss'].append(round(avg_train_loss, 4))
222
+ history['val_loss'].append(round(val_loss, 4))
223
+ history['val_acc'].append(round(val_acc, 4))
224
+ history['val_f1'].append(round(val_f1, 4))
225
+
226
+ print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")
227
+
228
+ if val_f1 > best_val_f1:
229
+ best_val_f1 = val_f1
230
+ classifier.save_model(best_model_path)
231
+ print(f" [+] Best model saved with F1: {val_f1:.4f}")
232
+
233
+ early_stopping(val_loss, epoch + 1)
234
+ if early_stopping.early_stop:
235
+ print("Stopping early.")
236
+ break
237
+
238
+ # ── Final evaluation on TEST set ────────────────────────────────
239
+ classifier.load_model(best_model_path)
240
+ test_loss, test_acc, test_prec, test_rec, test_f1, all_preds, all_labels = evaluate_model_full(classifier, test_loader)
241
+
242
+ training_duration = round(time.time() - start_time, 2)
243
+
244
+ # Per-class metrics
245
+ per_class_p, per_class_r, per_class_f1, per_class_support = precision_recall_fscore_support(
246
+ all_labels, all_preds, average=None, zero_division=0
247
+ )
248
+ per_class_metrics = {}
249
+ for i, name in enumerate(INTENT_NAMES):
250
+ per_class_metrics[name] = {
251
+ 'precision': round(float(per_class_p[i]), 4),
252
+ 'recall': round(float(per_class_r[i]), 4),
253
+ 'f1_score': round(float(per_class_f1[i]), 4),
254
+ 'support': int(per_class_support[i])
255
+ }
256
+
257
+ # Confusion matrix
258
+ cm = confusion_matrix(all_labels, all_preds).tolist()
259
+
260
+ # Classification report
261
+ cls_report = classification_report(all_labels, all_preds, target_names=INTENT_NAMES, zero_division=0)
262
+
263
+ # ── Save results ───────────────────────────────────────────────
264
+ results = {
265
+ 'model': 'TinyBert-CNN',
266
+ 'hyperparameters': hyperparams,
267
+ 'training_duration_seconds': training_duration,
268
+ 'epochs_trained': len(history['train_loss']),
269
+ 'metrics': {
270
+ 'accuracy': round(test_acc, 4),
271
+ 'f1_score': round(test_f1, 4),
272
+ 'precision': round(test_prec, 4),
273
+ 'recall': round(test_rec, 4),
274
+ 'test_loss': round(test_loss, 4)
275
+ },
276
+ 'per_class_metrics': per_class_metrics,
277
+ 'confusion_matrix': cm,
278
+ 'training_history': history,
279
+ 'classification_report': cls_report
280
+ }
281
+
282
+ with open('training_results.json', 'w') as f:
283
+ json.dump(results, f, indent=4)
284
+
285
+ print(f"\n{'='*60}")
286
+ print(f"TRAINING COMPLETE ({training_duration:.1f}s)")
287
+ print(f"{'='*60}")
288
+ print(f"Test Acc: {test_acc:.4f} | Test F1: {test_f1:.4f} | Test Loss: {test_loss:.4f}")
289
+ print(f"\nPer-class results:")
290
+ print(cls_report)
291
+ print(f"Confusion Matrix:")
292
+ for row in cm:
293
+ print(f" {row}")
294
+ print(f"\n[+] Results saved to 'training_results.json'")
295
+
296
+
297
+ if __name__ == '__main__':
298
+ main()