"""
Advanced early stopping with multi-metric support.
Helps prevent overfitting to any single metric by tracking several
metrics simultaneously through a weighted composite score.
"""

import json
from pathlib import Path

import torch


class MultiMetricEarlyStopping:
    """
    Early stopping that considers multiple metrics with weighted scores.
    
    Advantages over single-metric stopping:
    - Avoids improving one metric at the expense of others
    - Better overall model performance
    - More stable convergence
    
    Example metric weights:
        {'loss': 0.2, 'accuracy': 0.4, 'bertscore': 0.3, 'f1': 0.1}
    """
    
    def __init__(self, patience=5, metric_weights=None, mode='maximize',
                 save_dir=None, verbose=True):
        """
        Args:
            patience: Number of evaluations with no improvement before stopping
            metric_weights: Dict of {metric_name: weight}. If None, uses 'loss' only
            mode: 'maximize' if a larger composite score is better,
                'minimize' if a smaller one is
            save_dir: Directory to save best model
            verbose: Print progress
        """
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.best_metrics = None
        self.save_dir = Path(save_dir) if save_dir else None
        self.verbose = verbose
        self.mode = mode
        
        # Default metric weights if not provided
        if metric_weights is None:
            self.metric_weights = {'loss': 1.0}
        else:
            self.metric_weights = metric_weights
            # Normalize weights to sum to 1
            total_weight = sum(self.metric_weights.values())
            self.metric_weights = {k: v/total_weight for k, v in self.metric_weights.items()}
        
        self.history = []
        
        if self.save_dir:
            self.save_dir.mkdir(parents=True, exist_ok=True)
    
    def compute_score(self, metrics):
        """
        Compute weighted score from multiple metrics.
        
        Args:
            metrics: Dict of metric_name -> value
        
        Returns:
            Weighted score
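        
        Example (illustrative numbers):
            With normalized weights {'loss': 0.4, 'accuracy': 0.6},
            metrics {'loss': 0.5, 'accuracy': 0.8} and mode='maximize',
            the loss term is negated, so
            score = (-0.5) * 0.4 + 0.8 * 0.6 = 0.28.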
        """
        score = 0.0
        
        for metric_name, weight in self.metric_weights.items():
            if metric_name not in metrics:
                if self.verbose:
                    print(f"[WARNING] Metric '{metric_name}' not found in current metrics")
                continue
            
            metric_value = metrics[metric_name]
            
            # Losses improve as they decrease, so sign them opposite to
            # score-like metrics (accuracy, F1, etc.), which improve as
            # they increase.
            if 'loss' in metric_name.lower():
                metric_contribution = -metric_value if self.mode == 'maximize' else metric_value
            else:
                metric_contribution = metric_value if self.mode == 'maximize' else -metric_value
            
            score += metric_contribution * weight
        
        return score
    
    def __call__(self, metrics, model=None, epoch=None):
        """
        Check if should stop training.
        
        Args:
            metrics: Dict of metric_name -> value
            model: Model to save if best
            epoch: Current epoch number
        
        Returns:
            True if should stop, False otherwise
        """
        score = self.compute_score(metrics)
        
        # Store history
        self.history.append({
            'epoch': epoch,
            'score': score,
            'metrics': metrics.copy()
        })
        
        # Respect the configured mode: a larger composite score is better
        # when maximizing, a smaller one when minimizing.
        if self.best_score is None:
            improved = True
        elif self.mode == 'maximize':
            improved = score > self.best_score
        else:
            improved = score < self.best_score

        if improved:
            self.best_score = score
            self.best_metrics = metrics.copy()
            self.counter = 0
            if model is not None and self.save_dir:
                self._save_checkpoint(model, epoch, metrics)
            if self.verbose:
                print(f"✓ Epoch {epoch}: New best score {score:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"✗ Epoch {epoch}: No improvement ({self.counter}/{self.patience})")
        
        # Check if should stop
        if self.counter >= self.patience:
            if self.verbose:
                print("\n[EARLY STOPPING] Patience exceeded. Best metrics:")
                for k, v in self.best_metrics.items():
                    if isinstance(v, float):
                        print(f"  {k}: {v:.4f}")
                    else:
                        print(f"  {k}: {v}")
            return True
        
        return False
    
    def _save_checkpoint(self, model, epoch, metrics):
        """Save best model checkpoint."""
        if self.save_dir is None:
            return
        
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'metrics': metrics
        }
        
        # Note: each new best is written to its own file; earlier
        # checkpoints are not removed.
        save_path = self.save_dir / f"best_checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint, save_path)
        
        # Also save metrics record
        metrics_path = self.save_dir / f"best_metrics_epoch_{epoch}.json"
        with open(metrics_path, 'w') as f:
            json.dump(metrics, f, indent=2, default=str)
        
        if self.verbose:
            print(f"  💾 Saved checkpoint to {save_path}")
    
    def get_best_metrics(self):
        """Return best metrics found during training."""
        return self.best_metrics
    
    def get_history(self):
        """Return training history."""
        return self.history
    
    def plot_metrics(self, save_path=None):
        """
        Plot metric progression during training.
        
        Args:
            save_path: Path to save figure
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("[WARNING] matplotlib not installed, cannot plot")
            return
        
        if not self.history:
            print("[WARNING] No history to plot")
            return
        
        # Fall back to the history index when no epoch number was recorded.
        epochs = [h['epoch'] if h['epoch'] is not None else i
                  for i, h in enumerate(self.history)]
        scores = [h['score'] for h in self.history]
        
        plt.figure(figsize=(10, 6))
        plt.plot(epochs, scores, 'b-o', label='Composite Score')
        plt.axhline(y=self.best_score, color='r', linestyle='--', label=f'Best: {self.best_score:.4f}')
        plt.xlabel('Epoch')
        plt.ylabel('Score')
        plt.legend()
        plt.title('Early Stopping - Composite Metric Score')
        plt.grid(True, alpha=0.3)
        
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"[INFO] Metric plot saved to {save_path}")
        
        plt.close()
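
# --- Usage sketch: MultiMetricEarlyStopping -------------------------------
# Minimal, illustrative wiring into a training loop. `train_one_epoch`,
# `evaluate`, and the model/loader names are hypothetical placeholders,
# not part of this module; the metric keys must match whatever your
# evaluation code returns.
#
#   stopper = MultiMetricEarlyStopping(
#       patience=3,
#       metric_weights={'loss': 0.3, 'accuracy': 0.5, 'f1': 0.2},
#       mode='maximize',
#       save_dir='checkpoints',
#   )
#   for epoch in range(num_epochs):
#       train_one_epoch(model, train_loader)
#       metrics = evaluate(model, val_loader)  # e.g. {'loss': 0.41, 'accuracy': 0.87, 'f1': 0.83}
#       if stopper(metrics, model=model, epoch=epoch):
#           break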


class DynamicClassWeights:
    """
    Compute class weights dynamically from training data.
    Adapts to actual data distribution.
    """
    
    @staticmethod
    def compute_weights(dataloader, device='cpu'):
        """
        Compute class weights from data distribution.
        
        Args:
            dataloader: DataLoader to analyze
            device: Device for tensor
        
        Returns:
            Tensor of class weights
        """
        class_counts = {}
        
        for batch in dataloader:
            labels = batch.get('label_closed', None)
            if labels is None:
                continue
            
            # Count occurrences of each class
            unique_labels, counts = torch.unique(labels, return_counts=True)
            for label, count in zip(unique_labels, counts):
                label_idx = label.item()
                if label_idx >= 0:  # Ignore negative indices
                    class_counts[label_idx] = class_counts.get(label_idx, 0) + count.item()
        
        if not class_counts:
            # No usable labels were found; fall back to uniform weights
            # (a binary task is assumed here).
            return torch.ones(2, device=device)
        
        # Compute inverse frequency weights
        total_samples = sum(class_counts.values())
        num_classes = len(class_counts)
        
        # Indexed by class id; classes never seen in the data keep weight 0.
        weights = torch.zeros(max(class_counts.keys()) + 1, device=device)
        for class_idx, count in class_counts.items():
            # Weight = total / (num_classes * count) - higher weight for rarer classes
            weight = total_samples / (num_classes * max(count, 1))
            weights[class_idx] = weight
        
        # Normalize to sum to num_classes
        weights = weights / weights.sum() * num_classes
        
        print("[INFO] Dynamic Class Weights:")
        for class_idx in sorted(class_counts.keys()):
            print(f"  Class {class_idx}: Weight={weights[class_idx]:.4f}, Samples={class_counts[class_idx]}")
        
        return weights
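

if __name__ == "__main__":
    # Usage sketch / smoke test on synthetic, made-up data (no real
    # training). 'label_closed' is the batch key compute_weights expects;
    # a plain list of dict batches stands in for a DataLoader here, since
    # compute_weights only iterates over it and reads that key.
    import torch.nn as nn

    # Imbalanced toy labels: 90 samples of class 0, 10 of class 1.
    labels = torch.cat([
        torch.zeros(90, dtype=torch.long),
        torch.ones(10, dtype=torch.long),
    ])
    batches = [{'label_closed': labels[i:i + 20]} for i in range(0, 100, 20)]

    weights = DynamicClassWeights.compute_weights(batches)

    # Typical downstream use: a frequency-weighted classification loss.
    criterion = nn.CrossEntropyLoss(weight=weights)
    print("Loss weights:", criterion.weight)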