| """ |
| Gradient Clipping Experiment |
| |
| This script demonstrates how gradient clipping stabilizes training by preventing |
| sudden large weight updates caused by rare, high-loss data points. |
| |
| Experiment Setup: |
| - Simple model: Embedding(4, 16) -> Linear(16, 4) |
| - Vocabulary: ['A', 'B', 'C', 'D'] |
| - Dataset: 1000 samples with imbalanced targets (990 'A', 10 'B') |
| - Compare training with and without gradient clipping |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import random |
|
|
| |
# Global seed shared by every reproducibility helper in this script.
SEED = 42


def set_seeds(seed=SEED):
    """Seed torch, numpy, and the stdlib RNG so runs are exactly repeatable."""
    # Same seeding order as elsewhere in the file: torch, then numpy, then random.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
|
|
|
| |
| |
| |
|
|
class SimpleNextTokenModel(nn.Module):
    """
    Minimal next-token predictor: look up a learned token embedding, then
    project it back onto the vocabulary with one linear layer.
    """

    def __init__(self, vocab_size=4, embedding_dim=16):
        super().__init__()
        # Attribute names are part of the state_dict contract; keep them stable.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """
        Map token indices of shape (batch_size,) to logits of shape
        (batch_size, vocab_size).
        """
        return self.linear(self.embedding(x))
|
|
|
|
| |
| |
| |
|
|
def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED):
    """
    Build a synthetic next-token dataset with heavily imbalanced targets.

    Args:
        n_samples: Total number of (input, target) pairs.
        n_rare: Number of rare 'B' (class 1) targets to scatter in.
        seed: Random seed for reproducibility.

    Returns:
        Tuple of (inputs, targets, rare_indices): random token ids in
        [0, 4); targets that are mostly 'A' (0) with n_rare 'B' (1)
        entries; and the sorted positions of those rare targets.
    """
    set_seeds(seed)

    # Token ids for reference: 'A'->0, 'B'->1, 'C'->2, 'D'->3.
    vocab = {'A': 0, 'B': 1, 'C': 2, 'D': 3}

    # Inputs are uniform over the whole vocabulary.
    token_inputs = torch.randint(0, 4, (n_samples,))

    # Every target starts out as the common class 'A' (0)...
    token_targets = torch.zeros(n_samples, dtype=torch.long)

    # ...then a random handful of positions are flipped to the rare 'B' (1).
    rare_positions = random.sample(range(n_samples), n_rare)
    token_targets[rare_positions] = 1

    return token_inputs, token_targets, sorted(rare_positions)
|
|
|
|
| |
| |
| |
|
|
def compute_weight_norm(model):
    """Return the global L2 norm over all parameter tensors of *model*."""
    # Sum of squared per-parameter L2 norms equals the squared global norm.
    squared_norms = [p.data.norm(2).item() ** 2 for p in model.parameters()]
    return sum(squared_norms) ** 0.5
|
|
|
|
def get_initial_weights(seed=SEED):
    """Return a cloned state dict from a freshly seeded model, so both
    training runs can start from byte-identical initial weights."""
    set_seeds(seed)
    reference = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    return {key: tensor.clone() for key, tensor in reference.state_dict().items()}
|
|
|
|
def train_epoch(model, optimizer, criterion, inputs, targets, clip_grad=False, max_norm=1.0):
    """
    Train for one epoch at batch size 1, recording metrics at every step.

    Args:
        model: The neural network.
        optimizer: SGD optimizer.
        criterion: CrossEntropyLoss.
        inputs: Input token indices.
        targets: Target token indices.
        clip_grad: Whether to apply gradient clipping before the update.
        max_norm: Maximum gradient norm (used only if clip_grad is True).

    Returns:
        Tuple of (losses, grad_norms, weight_norms), one entry per step;
        grad_norms are measured BEFORE any clipping.
    """
    model.train()

    losses = []
    grad_norms = []
    weight_norms = []

    for step in range(len(inputs)):
        # Slice (not index) to keep the batch dimension of size 1.
        x = inputs[step:step + 1]
        y = targets[step:step + 1]

        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()

        # With max_norm=inf this never scales anything; it just returns the
        # total L2 gradient norm, i.e. the pre-clipping measurement.
        raw_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))

        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        optimizer.step()

        losses.append(loss.item())
        grad_norms.append(raw_grad_norm.item())
        weight_norms.append(compute_weight_norm(model))

    return losses, grad_norms, weight_norms
|
|
|
|
| |
| |
| |
|
|
def run_training(inputs, targets, rare_indices, clip_grad=False, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=None):
    """
    Run the complete training loop for one configuration.

    Args:
        inputs: Input token indices.
        targets: Target token indices.
        rare_indices: Indices of rare 'B' samples (unused here; kept for
            signature symmetry with the plotting helpers).
        clip_grad: Whether to apply gradient clipping.
        max_norm: Maximum gradient norm threshold.
        n_epochs: Number of training epochs.
        lr: Learning rate.
        init_weights: Optional initial state dict for reproducibility.

    Returns:
        Tuple of (all_losses, all_grad_norms, all_weight_norms) concatenated
        across every step of every epoch.
    """
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights:
        # Overwrite the seeded initialization with the shared starting point.
        model.load_state_dict(init_weights)

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    all_losses, all_grad_norms, all_weight_norms = [], [], []

    banner = '=' * 60
    mode = "WITH" if clip_grad else "WITHOUT"
    print(f"\n{banner}")
    print(f"Training {mode} gradient clipping (max_norm={max_norm})")
    print(f"{banner}")

    for epoch in range(n_epochs):
        losses, grad_norms, weight_norms = train_epoch(
            model, optimizer, criterion, inputs, targets,
            clip_grad=clip_grad, max_norm=max_norm,
        )

        all_losses += losses
        all_grad_norms += grad_norms
        all_weight_norms += weight_norms

        print(f"Epoch {epoch+1}/{n_epochs}: Avg Loss={np.mean(losses):.4f}, Max Grad Norm={np.max(grad_norms):.4f}")

    return all_losses, all_grad_norms, all_weight_norms
|
|
|
|
| |
| |
| |
|
|
def plot_metrics(losses, grad_norms, weight_norms, title, filename, rare_indices=None, n_samples=1000):
    """
    Plot training metrics: loss, gradient norm, and weight norm.

    Args:
        losses: List of losses per step.
        grad_norms: List of gradient norms per step (pre-clipping).
        weight_norms: List of weight norms per step.
        title: Plot title; a standalone "WITH" word triggers the clip-threshold line.
        filename: Output filename for the saved figure.
        rare_indices: Indices of rare 'B' samples (for vertical highlighting).
        n_samples: Number of samples per epoch (to repeat highlights per epoch).
    """
    fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)

    steps = range(len(losses))
    n_epochs = len(losses) // n_samples

    # Panel 1: per-step training loss.
    axes[0].plot(steps, losses, 'b-', alpha=0.7, linewidth=0.5)
    axes[0].set_ylabel('Training Loss', fontsize=12)
    axes[0].set_title(title, fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3)

    # Mark every occurrence of a rare 'B' sample in each epoch.
    if rare_indices:
        for epoch in range(n_epochs):
            for idx in rare_indices:
                step = epoch * n_samples + idx
                if step < len(losses):
                    axes[0].axvline(x=step, color='red', alpha=0.3, linewidth=0.5)

    # Panel 2: per-step gradient L2 norm.
    axes[1].plot(steps, grad_norms, 'g-', alpha=0.7, linewidth=0.5)
    axes[1].set_ylabel('Gradient L2 Norm', fontsize=12)
    axes[1].grid(True, alpha=0.3)

    # Draw the clip threshold only for the clipped run. NOTE: a plain
    # substring test ("WITH" in title) would also match "WITHOUT", so we
    # match whole words instead.
    title_words = title.upper().replace('(', ' ').split()
    if "WITH" in title_words:
        axes[1].axhline(y=1.0, color='red', linestyle='--', label='Clip threshold (1.0)')
        axes[1].legend()

    if rare_indices:
        for epoch in range(n_epochs):
            for idx in rare_indices:
                step = epoch * n_samples + idx
                if step < len(grad_norms):
                    axes[1].axvline(x=step, color='red', alpha=0.3, linewidth=0.5)

    # Panel 3: per-step weight L2 norm.
    axes[2].plot(steps, weight_norms, 'm-', alpha=0.7, linewidth=0.5)
    axes[2].set_ylabel('Weight L2 Norm', fontsize=12)
    axes[2].set_xlabel('Training Step', fontsize=12)
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()

    # Report the actual output path (the original printed a dead f-string).
    print(f"Plot saved to: {filename}")
|
|
|
|
def plot_comparison(metrics_no_clip, metrics_with_clip, rare_indices, filename, n_samples=1000):
    """
    Create a side-by-side comparison plot of the two training runs.

    Args:
        metrics_no_clip: (losses, grad_norms, weight_norms) without clipping.
        metrics_with_clip: (losses, grad_norms, weight_norms) with clipping.
        rare_indices: Indices of rare 'B' samples.
        filename: Output filename for the saved figure.
        n_samples: Number of samples per epoch.
    """
    fig, axes = plt.subplots(3, 2, figsize=(16, 12))

    losses_no, grads_no, weights_no = metrics_no_clip
    losses_with, grads_with, weights_with = metrics_with_clip

    steps = range(len(losses_no))
    n_epochs = len(losses_no) // n_samples

    # Left column: training without clipping.
    axes[0, 0].plot(steps, losses_no, 'b-', alpha=0.7, linewidth=0.5)
    axes[0, 0].set_ylabel('Training Loss', fontsize=11)
    axes[0, 0].set_title('WITHOUT Gradient Clipping', fontsize=13, fontweight='bold', color='red')
    axes[0, 0].grid(True, alpha=0.3)

    axes[1, 0].plot(steps, grads_no, 'g-', alpha=0.7, linewidth=0.5)
    axes[1, 0].set_ylabel('Gradient L2 Norm', fontsize=11)
    axes[1, 0].grid(True, alpha=0.3)

    axes[2, 0].plot(steps, weights_no, 'm-', alpha=0.7, linewidth=0.5)
    axes[2, 0].set_ylabel('Weight L2 Norm', fontsize=11)
    axes[2, 0].set_xlabel('Training Step', fontsize=11)
    axes[2, 0].grid(True, alpha=0.3)

    # Right column: training with clipping (threshold line on the grad panel).
    axes[0, 1].plot(steps, losses_with, 'b-', alpha=0.7, linewidth=0.5)
    axes[0, 1].set_title('WITH Gradient Clipping (max_norm=1.0)', fontsize=13, fontweight='bold', color='green')
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 1].plot(steps, grads_with, 'g-', alpha=0.7, linewidth=0.5)
    axes[1, 1].axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Clip threshold')
    axes[1, 1].legend(loc='upper right')
    axes[1, 1].grid(True, alpha=0.3)

    axes[2, 1].plot(steps, weights_with, 'm-', alpha=0.7, linewidth=0.5)
    axes[2, 1].set_xlabel('Training Step', fontsize=11)
    axes[2, 1].grid(True, alpha=0.3)

    # Highlight every rare-'B' step in every panel of both columns.
    for col in range(2):
        for row in range(3):
            for epoch in range(n_epochs):
                for idx in rare_indices:
                    step = epoch * n_samples + idx
                    if step < len(losses_no):
                        axes[row, col].axvline(x=step, color='red', alpha=0.2, linewidth=0.5)

    # Off-screen line (x=-100) exists purely to give the legend a handle.
    axes[0, 0].axvline(x=-100, color='red', alpha=0.5, linewidth=2, label="Rare 'B' samples")
    axes[0, 0].legend(loc='upper right')

    fig.suptitle('Effect of Gradient Clipping on Training Stability\n(Red lines indicate rare "B" samples)',
                 fontsize=14, fontweight='bold', y=1.02)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()

    # Report the actual output path (the original printed a dead f-string).
    print(f"Comparison plot saved to: {filename}")
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run the whole experiment: build data, train twice, plot, summarize.

    Returns a dict of summary statistics for both runs plus the rare indices.
    """
    banner = "=" * 60

    print(banner)
    print("GRADIENT CLIPPING EXPERIMENT")
    print(banner)
    print("\nThis experiment demonstrates how gradient clipping stabilizes")
    print("training by preventing sudden large weight updates caused by")
    print("rare, high-loss data points.\n")

    # Build the imbalanced dataset (990 'A' targets, 10 rare 'B' targets).
    inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED)

    print("Dataset created:")
    print(f"  Total samples: {len(inputs)}")
    print(f"  Target 'A' (0): {(targets == 0).sum().item()}")
    print(f"  Target 'B' (1): {(targets == 1).sum().item()}")
    print(f"  Rare 'B' indices: {rare_indices}")

    # Both runs start from the exact same initial weights.
    init_weights = get_initial_weights(seed=SEED)

    metrics_no_clip = run_training(
        inputs, targets, rare_indices,
        clip_grad=False, n_epochs=3, lr=0.1, init_weights=init_weights,
    )
    metrics_with_clip = run_training(
        inputs, targets, rare_indices,
        clip_grad=True, max_norm=1.0, n_epochs=3, lr=0.1, init_weights=init_weights,
    )

    print("\n" + banner)
    print("GENERATING PLOTS")
    print(banner)

    plot_metrics(
        *metrics_no_clip,
        "Training WITHOUT Gradient Clipping",
        "no_clipping.png",
        rare_indices,
    )
    plot_metrics(
        *metrics_with_clip,
        "Training WITH Gradient Clipping (max_norm=1.0)",
        "with_clipping.png",
        rare_indices,
    )
    plot_comparison(metrics_no_clip, metrics_with_clip, rare_indices, "comparison.png")

    print("\n" + banner)
    print("SUMMARY STATISTICS")
    print(banner)

    def summarize(header, losses, grads, weights):
        # Print one run's summary block and return the same values as a dict.
        print(header)
        print(f"  Max Gradient Norm: {max(grads):.4f}")
        print(f"  Mean Gradient Norm: {np.mean(grads):.4f}")
        print(f"  Std Gradient Norm: {np.std(grads):.4f}")
        print(f"  Final Weight Norm: {weights[-1]:.4f}")
        print(f"  Final Loss: {losses[-1]:.4f}")
        return {
            'max_grad': max(grads),
            'mean_grad': np.mean(grads),
            'std_grad': np.std(grads),
            'final_weight': weights[-1],
            'final_loss': losses[-1],
        }

    stats_no_clip = summarize("\nWithout Gradient Clipping:", *metrics_no_clip)
    stats_with_clip = summarize("\nWith Gradient Clipping (max_norm=1.0):", *metrics_with_clip)

    return {
        'no_clip': stats_no_clip,
        'with_clip': stats_with_clip,
        'rare_indices': rare_indices,
    }
|
|
|
|
if __name__ == "__main__":
    # Run the experiment, then print a closing banner.
    stats = main()
    closing_banner = "=" * 60
    print("\n" + closing_banner)
    print("EXPERIMENT COMPLETE!")
    print(closing_banner)
|
|