| """ | |
| Advanced early stopping with multi-metric support. | |
| Prevents overfitting by tracking multiple metrics simultaneously. | |
| """ | |
| import numpy as np | |
| from pathlib import Path | |
| import torch | |
| import json | |


class MultiMetricEarlyStopping:
    """
    Early stopping that considers multiple metrics with weighted scores.

    Advantages over single-metric stopping:
    - Prevents overfitting on one metric while degrading others
    - Better overall model performance
    - More stable convergence

    Example metric weights:
        {'loss': 0.2, 'accuracy': 0.4, 'bertscore': 0.3, 'f1': 0.1}

    See _demo_early_stopping() below for an illustrative usage sketch.
    """

    def __init__(self, patience=5, metric_weights=None, mode='maximize',
                 save_dir=None, verbose=True):
        """
        Args:
            patience: Number of evaluations with no improvement before stopping
            metric_weights: Dict of {metric_name: weight}. If None, uses 'loss' only
            mode: 'maximize' or 'minimize'
            save_dir: Directory to save best model
            verbose: Print progress
        """
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.best_metrics = None
        self.save_dir = Path(save_dir) if save_dir else None
        self.verbose = verbose
        self.mode = mode

        # Default metric weights if not provided
        if metric_weights is None:
            self.metric_weights = {'loss': 1.0}
        else:
            self.metric_weights = metric_weights

        # Normalize weights so they sum to 1
        total_weight = sum(self.metric_weights.values())
        self.metric_weights = {k: v / total_weight for k, v in self.metric_weights.items()}

        self.history = []

        if self.save_dir:
            self.save_dir.mkdir(parents=True, exist_ok=True)

    def compute_score(self, metrics):
        """
        Compute a weighted score from multiple metrics.

        Args:
            metrics: Dict of metric_name -> value
        Returns:
            Weighted score
        """
        score = 0.0
        for metric_name, weight in self.metric_weights.items():
            if metric_name not in metrics:
                if self.verbose:
                    print(f"[WARNING] Metric '{metric_name}' not found in current metrics")
                continue

            metric_value = metrics[metric_name]

            if 'loss' in metric_name.lower():
                # Losses should be minimized, so negate them in 'maximize' mode
                metric_contribution = -metric_value if self.mode == 'maximize' else metric_value
            else:
                # Most other metrics (accuracy, F1, etc.) should be maximized,
                # so negate them in 'minimize' mode
                metric_contribution = metric_value if self.mode == 'maximize' else -metric_value

            score += metric_contribution * weight

        return score
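
    # Worked example (illustrative numbers, not taken from any real run):
    # with normalized weights {'loss': 0.2, 'accuracy': 0.8} and metrics
    # {'loss': 0.5, 'accuracy': 0.9} in 'maximize' mode, the composite score is
    #     (-0.5 * 0.2) + (0.9 * 0.8) = -0.10 + 0.72 = 0.62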

    def __call__(self, metrics, model=None, epoch=None):
        """
        Check whether training should stop.

        Args:
            metrics: Dict of metric_name -> value
            model: Model to save if best
            epoch: Current epoch number
        Returns:
            True if training should stop, False otherwise
        """
        score = self.compute_score(metrics)

        # Store history
        self.history.append({
            'epoch': epoch,
            'score': score,
            'metrics': metrics.copy()
        })

        # A score counts as an improvement if it is higher in 'maximize' mode
        # or lower in 'minimize' mode
        if self.best_score is None:
            improved = True
        elif self.mode == 'maximize':
            improved = score > self.best_score
        else:
            improved = score < self.best_score

        if improved:
            self.best_score = score
            self.best_metrics = metrics.copy()
            self.counter = 0
            if model is not None and self.save_dir:
                self._save_checkpoint(model, epoch, metrics)
            if self.verbose:
                print(f"✓ Epoch {epoch}: New best score {score:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"✗ Epoch {epoch}: No improvement ({self.counter}/{self.patience})")

        # Check whether patience has been exhausted
        if self.counter >= self.patience:
            if self.verbose:
                print("\n[EARLY STOPPING] Patience exceeded. Best metrics:")
                for k, v in self.best_metrics.items():
                    if isinstance(v, float):
                        print(f"  {k}: {v:.4f}")
            return True
        return False

    def _save_checkpoint(self, model, epoch, metrics):
        """Save best model checkpoint."""
        if self.save_dir is None:
            return

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'metrics': metrics
        }
        save_path = self.save_dir / f"best_checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint, save_path)

        # Also save metrics record
        metrics_path = self.save_dir / f"best_metrics_epoch_{epoch}.json"
        with open(metrics_path, 'w') as f:
            json.dump(metrics, f, indent=2, default=str)

        if self.verbose:
            print(f"  💾 Saved checkpoint to {save_path}")

    def get_best_metrics(self):
        """Return best metrics found during training."""
        return self.best_metrics

    def get_history(self):
        """Return training history."""
        return self.history

    def plot_metrics(self, save_path=None):
        """
        Plot metric progression during training.

        Args:
            save_path: Path to save figure
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("[WARNING] matplotlib not installed, cannot plot")
            return

        if not self.history:
            print("[WARNING] No history to plot")
            return

        epochs = [h['epoch'] for h in self.history]
        scores = [h['score'] for h in self.history]

        plt.figure(figsize=(10, 6))
        plt.plot(epochs, scores, 'b-o', label='Composite Score')
        plt.axhline(y=self.best_score, color='r', linestyle='--', label=f'Best: {self.best_score:.4f}')
        plt.xlabel('Epoch')
        plt.ylabel('Score')
        plt.legend()
        plt.title('Early Stopping - Composite Metric Score')
        plt.grid(True, alpha=0.3)

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"[INFO] Metric plot saved to {save_path}")
        plt.close()
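

# Illustrative usage sketch for MultiMetricEarlyStopping. The model, weights,
# and metric values below are fabricated purely to demonstrate the API; a real
# run would pass validation metrics computed by the project's own evaluation loop.
def _demo_early_stopping():
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in model so checkpoint saving has a state_dict
    stopper = MultiMetricEarlyStopping(
        patience=2,
        metric_weights={'loss': 0.3, 'accuracy': 0.7},
        mode='maximize',
        save_dir=None,  # set a directory here to enable checkpoint saving
        verbose=True,
    )
    # Fabricated validation metrics: one improvement, then stagnation
    fake_epochs = [
        {'loss': 0.90, 'accuracy': 0.55},
        {'loss': 0.70, 'accuracy': 0.62},
        {'loss': 0.72, 'accuracy': 0.61},
        {'loss': 0.74, 'accuracy': 0.60},
    ]
    for epoch, metrics in enumerate(fake_epochs):
        if stopper(metrics, model=model, epoch=epoch):
            break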


class DynamicClassWeights:
    """
    Compute class weights dynamically from training data.
    Adapts to the actual data distribution.
    """

    @staticmethod
    def compute_weights(dataloader, device='cpu'):
        """
        Compute class weights from the data distribution.

        Args:
            dataloader: DataLoader to analyze
            device: Device for the returned tensor
        Returns:
            Tensor of class weights
        """
        class_counts = {}
        for batch in dataloader:
            labels = batch.get('label_closed', None)
            if labels is None:
                continue
            # Count occurrences of each class
            unique_labels, counts = torch.unique(labels, return_counts=True)
            for label, count in zip(unique_labels, counts):
                label_idx = label.item()
                if label_idx >= 0:  # Ignore negative indices
                    class_counts[label_idx] = class_counts.get(label_idx, 0) + count.item()

        if not class_counts:
            # Default weights if no labelled data was found
            return torch.ones(2, device=device)

        # Compute inverse-frequency weights
        total_samples = sum(class_counts.values())
        num_classes = len(class_counts)
        weights = torch.zeros(max(class_counts.keys()) + 1, device=device)
        for class_idx, count in class_counts.items():
            # Weight = total / (num_classes * count): rarer classes get higher weights
            weights[class_idx] = total_samples / (num_classes * max(count, 1))

        # Normalize so the weights sum to num_classes
        weights = weights / weights.sum() * num_classes

        print("[INFO] Dynamic Class Weights:")
        for class_idx in sorted(class_counts.keys()):
            print(f"  Class {class_idx}: Weight={weights[class_idx]:.4f}, Samples={class_counts[class_idx]}")

        return weights.to(device)
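

# Illustrative usage sketch for DynamicClassWeights. compute_weights expects
# batches that are dicts containing a 'label_closed' tensor; the toy "loader"
# below fabricates an imbalanced two-class distribution purely to show the
# inverse-frequency weighting. The __main__ guard keeps it out of normal imports.
if __name__ == "__main__":
    toy_batches = [
        {'label_closed': torch.tensor([0, 0, 0, 1])},
        {'label_closed': torch.tensor([0, 0, 1, 0])},
    ]
    # Class 0 appears six times and class 1 twice, so class 1 gets the larger weight
    class_weights = DynamicClassWeights.compute_weights(toy_batches)
    print(class_weights)  # expected: tensor([0.5000, 1.5000])

    _demo_early_stopping()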