| """ |
| Final Gradient Clipping Experiment: Testing Physics-of-AI Predictions |
| |
| Key insights from previous experiments: |
| 1. With extreme imbalance (99:1), neither model learns rare class |
| 2. Gradient clipping's benefit is in STABILITY, not learning rare classes per se |
| 3. The key effect is on WEIGHT NORM STABILITY and GRADIENT SPIKE HANDLING |
| |
| This experiment tests: |
| 1. Prediction 2: Representation Collapse - effective dim variance without clipping |
| 2. Prediction 4: Rare Sample Learning - using moderate imbalance (80:20) |
| 3. NEW: Weight norm stability analysis |
| 4. NEW: Gradient spike analysis at rare sample positions |
| """ |
|
|
import random
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
|
|
# Global RNG seed: default for set_seeds() and the dataset builders, and the
# seed train_with_tracking() / run_experiment_suite() reset to before building models.
SEED = 42
|
|
|
|
def set_seeds(seed=SEED):
    """Reset the torch, NumPy and Python `random` generators to *seed*.

    The generators are seeded in that fixed order so every run that calls
    this before drawing random numbers is reproducible.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
|
|
|
|
class SimpleNextTokenModel(nn.Module):
    """Token-embedding lookup followed by a linear read-out over the vocabulary.

    Token ids of any shape map to logits with an extra trailing dimension of
    size ``vocab_size``.
    """

    def __init__(self, vocab_size=4, embedding_dim=16):
        super().__init__()
        # The attribute names are the state_dict keys ("embedding.*", "linear.*")
        # and the creation order fixes how the seeded torch RNG is consumed,
        # so both must stay as they are for reproducible initial weights.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Return next-token logits for the token ids in *x*."""
        return self.linear(self.embedding(x))

    def get_embeddings(self):
        """Return a detached copy of the (vocab_size, embedding_dim) embedding table."""
        return self.embedding.weight.detach().clone()
|
|
|
|
def compute_effective_dimension(embedding_matrix: torch.Tensor) -> float:
    """PCA-based effective dimensionality of the rows of *embedding_matrix*.

    The rows are centred and their covariance eigenvalues are normalised into
    a distribution p. The result is exp(H(p)), where H is the Shannon entropy
    (natural log). It equals k when the variance is spread evenly over k
    orthogonal directions and 1 when all variance lies along one direction.

    A matrix with fewer than two rows, or with zero total variance (every row
    identical, i.e. fully collapsed), has no spread and returns 0.0.
    Previously such inputs returned NaN, or the *maximum* dimension, because
    every eigenvalue was clamped to the same 1e-10 floor.

    Args:
        embedding_matrix: 2-D tensor, one sample (e.g. token embedding) per row.

    Returns:
        Effective dimension as a Python float.
    """
    # float64 keeps near-zero eigenvalues of rank-deficient covariances clean.
    x = embedding_matrix.detach().to(torch.float64)
    n_rows = x.shape[0]
    if n_rows < 2:
        return 0.0

    centered = x - x.mean(dim=0, keepdim=True)
    cov = centered.T @ centered / (n_rows - 1)
    # eigvalsh may return tiny negative values from round-off; they carry no variance.
    eigenvalues = torch.linalg.eigvalsh(cov).clamp(min=0.0)
    total = eigenvalues.sum()
    if total <= 0:
        return 0.0

    p = eigenvalues / total
    # xlogy(0, 0) == 0, so zero-variance directions add nothing to the entropy.
    entropy = -torch.xlogy(p, p).sum()
    return torch.exp(entropy).item()
|
|
|
|
def compute_per_class_accuracy(model: nn.Module, inputs: torch.Tensor,
                               targets: torch.Tensor,
                               num_classes: Optional[int] = None) -> Dict[int, Optional[float]]:
    """Return argmax accuracy of *model* on (inputs, targets), split by target class.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even on error, so calling this
    from a training loop does not silently leave the model in eval mode.

    Args:
        model: module whose output has class scores along dim 1.
        inputs: model inputs for the whole evaluation set.
        targets: integer class label per sample.
        num_classes: classes to report, 0..num_classes-1. Defaults to the
            width of the model's class dimension (4 for SimpleNextTokenModel).

    Returns:
        ``{class_idx: accuracy}``. The value is None for a class that has no
        sample in *targets*.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(inputs)
    finally:
        model.train(was_training)

    predictions = logits.argmax(dim=1)
    if num_classes is None:
        num_classes = logits.shape[1]

    accuracies = {}
    for class_idx in range(num_classes):
        mask = targets == class_idx
        if mask.any():
            accuracies[class_idx] = (predictions[mask] == class_idx).float().mean().item()
        else:
            accuracies[class_idx] = None
    return accuracies
|
|
|
|
def create_dataset_moderate_imbalance(n_samples=1000, rare_ratio=0.2, seed=SEED):
    """Create dataset with moderate imbalance (e.g., 80:20).

    After seeding every RNG with *seed*, input tokens are drawn uniformly from
    {0, 1, 2, 3}. ``int(n_samples * rare_ratio)`` distinct random positions are
    labelled class 1 (rare); every other position is class 0.

    NOTE(review): labels are chosen independently of the input tokens, so the
    only thing a model can learn from this data is the class prior.

    Returns:
        (inputs, targets, rare_indices): two 1-D int64 tensors of length
        n_samples and the sorted list of class-1 positions.
    """
    set_seeds(seed)
    token_ids = torch.randint(0, 4, (n_samples,))
    rare_positions = sorted(random.sample(range(n_samples), int(n_samples * rare_ratio)))
    labels = torch.zeros(n_samples, dtype=torch.long)
    labels[rare_positions] = 1
    return token_ids, labels, rare_positions
|
|
|
|
def create_dataset_extreme_imbalance(n_samples=1000, n_rare=10, seed=SEED):
    """Create dataset with extreme imbalance (99:1 with the defaults).

    After seeding every RNG with *seed*, input tokens are drawn uniformly from
    {0, 1, 2, 3}. Exactly *n_rare* distinct random positions are labelled
    class 1; all others are class 0. ``random.sample`` raises ValueError if
    n_rare is negative or exceeds n_samples.

    NOTE(review): as with the moderate dataset, labels do not depend on the
    input tokens.

    Returns:
        (inputs, targets, rare_indices): two 1-D int64 tensors of length
        n_samples and the sorted list of class-1 positions.
    """
    set_seeds(seed)
    token_ids = torch.randint(0, 4, (n_samples,))
    labels = torch.zeros_like(token_ids)
    rare_positions = sorted(random.sample(range(n_samples), n_rare))
    labels[rare_positions] = 1
    return token_ids, labels, rare_positions
|
|
|
|
def _total_weight_norm(model: nn.Module) -> float:
    """L2 norm of all parameters of *model* viewed as one flat vector."""
    return sum(p.detach().norm(2).item() ** 2 for p in model.parameters()) ** 0.5


def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor,
                        rare_indices: List[int], clip_grad: bool = False,
                        max_norm: float = 1.0, n_epochs: int = 10,
                        lr: float = 0.1, init_weights=None,
                        track_every: int = 50) -> Dict:
    """Train a SimpleNextTokenModel with per-sample SGD and record diagnostics.

    Every epoch walks the dataset in its stored order, taking one optimizer
    step per sample.

    Args:
        inputs: token ids, one per sample (sliced one at a time).
        targets: class label per sample.
        rare_indices: positions (within an epoch) of rare-class samples; the
            global steps at which they are processed go to 'rare_steps'.
        clip_grad: clip the total gradient norm to *max_norm* before each step.
        max_norm: clipping threshold used when *clip_grad* is true.
        n_epochs: passes over the data.
        lr: SGD learning rate.
        init_weights: optional state_dict loaded into the fresh model (if truthy).
        track_every: evaluate effective dimension and per-class accuracy every
            this many steps (starting at step 0, after its update).

    Returns:
        Dict of histories:
        - 'losses': one entry per step.
        - 'grad_norms': total gradient L2 norm, measured *before* clipping.
        - 'weight_norms': total parameter norm after the step.
        - 'weight_norm_changes': |change| in weight norm per step; the first
          entry is measured against the initial weights.
        - 'effective_dims' and 'effective_dim_steps'.
        - 'class_accuracies' (class -> list) and 'accuracy_steps'; NaN marks
          a class with no samples.
        - 'rare_steps': steps that processed a rare sample.
    """
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights:
        model.load_state_dict({k: v.clone() for k, v in init_weights.items()})

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # clip_grad_norm_ returns the total norm measured before clipping, and an
    # infinite threshold leaves gradients unchanged. One call therefore records
    # the raw norm and applies clipping only when requested.
    clip_threshold = max_norm if clip_grad else float('inf')

    metrics = {
        'losses': [],
        'grad_norms': [],
        'weight_norms': [],
        'effective_dims': [],
        'effective_dim_steps': [],
        'class_accuracies': {0: [], 1: [], 2: [], 3: []},
        'accuracy_steps': [],
        'weight_norm_changes': [],
        'rare_steps': [],
    }

    rare_positions = set(rare_indices)
    n_samples = len(inputs)
    # Baseline before any update, so the first recorded change is a real
    # measurement rather than a placeholder 0.
    prev_weight_norm = _total_weight_norm(model)
    step = 0

    for _ in range(n_epochs):
        model.train()

        for i in range(n_samples):
            x = inputs[i:i + 1]
            y = targets[i:i + 1]

            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_threshold)
            optimizer.step()

            metrics['losses'].append(loss.item())
            metrics['grad_norms'].append(grad_norm.item())
            if i in rare_positions:
                metrics['rare_steps'].append(step)

            weight_norm = _total_weight_norm(model)
            metrics['weight_norms'].append(weight_norm)
            metrics['weight_norm_changes'].append(abs(weight_norm - prev_weight_norm))
            prev_weight_norm = weight_norm

            if step % track_every == 0:
                metrics['effective_dims'].append(
                    compute_effective_dimension(model.get_embeddings()))
                metrics['effective_dim_steps'].append(step)

                class_acc = compute_per_class_accuracy(model, inputs, targets)
                for cls_idx, history in metrics['class_accuracies'].items():
                    acc = class_acc.get(cls_idx)
                    # NaN, not 0.0: "no samples of this class" is not "0% correct".
                    history.append(float('nan') if acc is None else acc)
                metrics['accuracy_steps'].append(step)
                # The accuracy helper may switch the model to eval mode;
                # resume training mode for the rest of the epoch.
                model.train()

            step += 1

    return metrics
|
|
|
|
def run_experiment_suite():
    """Train unclipped and clipped models on two imbalance levels.

    All four runs start from one shared initial state_dict, built here right
    after seeding. The extreme (99:1) pair trains for 5 epochs and the
    moderate (80:20) pair for 10. Every run uses lr=0.1 and track_every=100;
    clipped runs use max_norm=1.0.

    Returns:
        {'extreme': ..., 'moderate': ...}. Each value holds the 'no_clip' and
        'with_clip' metric dicts from train_with_tracking() plus the
        dataset's 'rare_indices'.
    """
    rule = "=" * 70
    for banner_line in (rule, "FINAL GRADIENT CLIPPING EXPERIMENT",
                        "Testing Physics-of-AI Predictions", rule):
        print(banner_line)

    set_seeds(SEED)
    template = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    shared_init = {name: tensor.clone() for name, tensor in template.state_dict().items()}

    def train_pair(inputs, targets, rare, n_epochs):
        """Train the unclipped model, then the clipped one, on a single dataset."""
        shared_kwargs = dict(n_epochs=n_epochs, lr=0.1,
                             init_weights=shared_init, track_every=100)
        print("\nTraining WITHOUT clipping...")
        unclipped = train_with_tracking(inputs, targets, rare,
                                        clip_grad=False, **shared_kwargs)
        print("Training WITH clipping...")
        clipped = train_with_tracking(inputs, targets, rare,
                                      clip_grad=True, max_norm=1.0, **shared_kwargs)
        return {'no_clip': unclipped, 'with_clip': clipped, 'rare_indices': rare}

    experiments = [
        ('extreme', "EXPERIMENT 1: EXTREME IMBALANCE (99:1)", 5,
         lambda: create_dataset_extreme_imbalance(n_samples=1000, n_rare=10, seed=SEED)),
        ('moderate', "EXPERIMENT 2: MODERATE IMBALANCE (80:20)", 10,
         lambda: create_dataset_moderate_imbalance(n_samples=1000, rare_ratio=0.2, seed=SEED)),
    ]

    results = {}
    for key, heading, n_epochs, build_dataset in experiments:
        print(f"\n{rule}\n{heading}\n{rule}")
        inputs, targets, rare = build_dataset()
        n_common = (targets == 0).sum().item()
        n_rare = (targets == 1).sum().item()
        print(f"Dataset: {n_common} common, {n_rare} rare")
        results[key] = train_pair(inputs, targets, rare, n_epochs)

    return results
|
|
|
|
def plot_final_comparison(results: Dict, filename: str, clip_threshold: float = 1.0):
    """Save a 5x2 figure comparing the clipped and unclipped training runs.

    Columns are the 'extreme' and 'moderate' experiments. Rows, top to bottom:
    - weight norm;
    - |weight-norm change| per step;
    - gradient norm, with the clipping threshold as a dashed line;
    - effective embedding dimension;
    - class-1 ("rare class B") accuracy.

    Red curves are the run without clipping; green curves are the run with it.

    Args:
        results: mapping with 'extreme' and 'moderate' entries, each holding
            'no_clip' and 'with_clip' metric dicts from train_with_tracking().
        filename: output image path.
        clip_threshold: height of the dashed "Clip threshold" line. It should
            match the max_norm of the clipped runs; it used to be hard-coded
            to 1.0.
    """
    def per_step(values):
        """Pair a per-step history with its step indices 0..len-1."""
        return range(len(values)), values

    columns = (
        ('extreme', 'EXTREME (99:1)', 'EXTREME'),
        ('moderate', 'MODERATE (80:20)', 'MODERATE'),
    )
    # One entry per grid row:
    # (series extractor, y-label, title template, line format, line style, draw threshold line)
    rows = (
        (lambda m: per_step(m['weight_norms']),
         'Weight Norm', '{long} - Weight Norm Evolution',
         '-', dict(alpha=0.7, linewidth=1), False),
        (lambda m: per_step(m['weight_norm_changes']),
         '|Weight Norm Change|', '{short} - Weight Norm Changes (Stability)',
         '-', dict(alpha=0.5, linewidth=0.5), False),
        (lambda m: per_step(m['grad_norms']),
         'Gradient Norm', '{short} - Gradient Norms',
         '-', dict(alpha=0.3, linewidth=0.5), True),
        (lambda m: (m['effective_dim_steps'], m['effective_dims']),
         'Effective Dimension', '{short} - Effective Dimensionality',
         '-o', dict(alpha=0.7, linewidth=2, markersize=4), False),
        (lambda m: (m['accuracy_steps'], m['class_accuracies'][1]),
         'Rare Class B Accuracy', '{short} - Rare Class Accuracy',
         '-', dict(alpha=0.7, linewidth=2), False),
    )
    last_row = len(rows) - 1

    fig = plt.figure(figsize=(20, 20))
    try:
        gs = fig.add_gridspec(len(rows), len(columns), hspace=0.35, wspace=0.25)

        for row, (series, ylabel, title, fmt, style, draw_threshold) in enumerate(rows):
            for col, (key, long_name, short_name) in enumerate(columns):
                ax = fig.add_subplot(gs[row, col])
                runs = results[key]
                ax.plot(*series(runs['no_clip']), 'r' + fmt, label='Without Clip', **style)
                ax.plot(*series(runs['with_clip']), 'g' + fmt, label='With Clip', **style)
                if draw_threshold:
                    ax.axhline(y=clip_threshold, color='black', linestyle='--',
                               linewidth=2, label='Clip threshold')
                if col == 0:
                    ax.set_ylabel(ylabel, fontsize=11)
                if row == last_row:
                    # Bottom row: accuracy in [0, 1] and the shared x-axis label.
                    ax.set_xlabel('Training Step', fontsize=11)
                    ax.set_ylim([0, 1.05])
                ax.set_title(title.format(long=long_name, short=short_name),
                             fontsize=12, fontweight='bold')
                ax.legend()
                ax.grid(True, alpha=0.3)

        fig.suptitle('Gradient Clipping Analysis: Physics-of-AI Predictions\n'
                     'Comparing Extreme (99:1) vs Moderate (80:20) Class Imbalance',
                     fontsize=14, fontweight='bold', y=1.01)
        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Final comparison plot saved to: {filename}")
|
|
|
|
def compute_statistics(results: Dict) -> Dict:
    """Reduce each experiment's metric histories to comparable scalars.

    Args:
        results: mapping from an experiment key (e.g. 'extreme', 'moderate')
            to a dict holding 'no_clip' and 'with_clip' metric dicts from
            train_with_tracking(). Every experiment is summarised in insertion
            order; other entries such as 'rare_indices' are ignored.

    Returns:
        ``{experiment: {stat: {'no_clip': value, 'with_clip': value}}}``.
        The stats are weight_norm_std, weight_change_mean, weight_change_max,
        grad_norm_max, effective_dim_std and final_rare_acc. An empty history
        yields NaN instead of raising. final_rare_acc is the last recorded
        class-1 accuracy, or 0 if none was recorded.
    """
    reductions = (
        ('weight_norm_std', 'weight_norms', np.std),
        ('weight_change_mean', 'weight_norm_changes', np.mean),
        ('weight_change_max', 'weight_norm_changes', np.max),
        ('grad_norm_max', 'grad_norms', np.max),
        ('effective_dim_std', 'effective_dims', np.std),
    )
    variants = ('no_clip', 'with_clip')

    def reduce(values, fn):
        # np.max raises on an empty list (np.mean/np.std only warn); report NaN uniformly.
        return float(fn(values)) if len(values) > 0 else float('nan')

    def final_rare_acc(metrics):
        history = metrics['class_accuracies'][1]
        return history[-1] if history else 0

    stats = {}
    for experiment, runs in results.items():
        summary = {
            stat: {v: reduce(runs[v][metric], fn) for v in variants}
            for stat, metric, fn in reductions
        }
        summary['final_rare_acc'] = {v: final_rare_acc(runs[v]) for v in variants}
        stats[experiment] = summary

    return stats
|
|
|
|
def print_summary(stats: Dict, clip_threshold: float = 1.0):
    """Print the experiment report for the output of compute_statistics().

    For the 'extreme' and 'moderate' settings it reports four things:
    - effective-dimension variance (Prediction 2);
    - final rare-class accuracy (Prediction 4);
    - weight-norm stability;
    - gradient-norm peaks.

    A verdict is INCONCLUSIVE when the compared values tie or either is NaN.
    For example, 0% vs 0% rare-class accuracy says nothing about clipping,
    yet a ``>=`` comparison would call it SUPPORTED.

    Args:
        stats: mapping with 'extreme' and 'moderate' entries shaped like
            compute_statistics() output.
        clip_threshold: max_norm used by the clipped runs. The largest
            unclipped gradient norm is reported as a multiple of it.

    Raises:
        ValueError: if clip_threshold is not positive.
    """
    import math

    if clip_threshold <= 0:
        raise ValueError(f"clip_threshold must be positive, got {clip_threshold}")

    def verdict(expected_larger, expected_smaller):
        """SUPPORTED iff the first value is strictly larger; INCONCLUSIVE on a tie or NaN."""
        if (math.isnan(expected_larger) or math.isnan(expected_smaller)
                or expected_larger == expected_smaller):
            return 'INCONCLUSIVE'
        return 'SUPPORTED' if expected_larger > expected_smaller else 'NOT SUPPORTED'

    labels = {'extreme': "EXTREME (99:1)", 'moderate': "MODERATE (80:20)"}

    print("\n" + "="*70)
    print("EXPERIMENT SUMMARY")
    print("="*70)

    for imbalance in ['extreme', 'moderate']:
        s = stats[imbalance]

        print(f"\n{labels[imbalance]}")
        print("-" * 50)

        dim_std = s['effective_dim_std']
        print("\n[PREDICTION 2] Representation Collapse (Effective Dim Variance):")
        print(f" WITHOUT Clipping: {dim_std['no_clip']:.6f}")
        print(f" WITH Clipping: {dim_std['with_clip']:.6f}")
        # Supported when the unclipped run's effective dimension varies more.
        print(f" Verdict: {verdict(dim_std['no_clip'], dim_std['with_clip'])}")

        rare_acc = s['final_rare_acc']
        print("\n[PREDICTION 4] Rare Sample Learning:")
        print(f" Final Rare Accuracy (WITHOUT): {rare_acc['no_clip']:.1%}")
        print(f" Final Rare Accuracy (WITH): {rare_acc['with_clip']:.1%}")
        # Supported when the clipped run ends with higher rare-class accuracy.
        print(f" Verdict: {verdict(rare_acc['with_clip'], rare_acc['no_clip'])}")

        print("\n[STABILITY] Weight Norm Analysis:")
        print(f" Weight Norm Std (WITHOUT): {s['weight_norm_std']['no_clip']:.4f}")
        print(f" Weight Norm Std (WITH): {s['weight_norm_std']['with_clip']:.4f}")
        print(f" Max Weight Change (WITHOUT): {s['weight_change_max']['no_clip']:.4f}")
        print(f" Max Weight Change (WITH): {s['weight_change_max']['with_clip']:.4f}")

        print("\n[GRADIENT] Analysis:")
        print(f" Max Gradient Norm (WITHOUT): {s['grad_norm_max']['no_clip']:.4f}")
        print(f" Max Gradient Norm (WITH): {s['grad_norm_max']['with_clip']:.4f}")
        print(f" Clipping Ratio: {s['grad_norm_max']['no_clip'] / clip_threshold:.1f}x threshold")
|
|
|
|
def main():
    """Run every experiment, save the comparison figure and print the summary.

    Returns:
        (results, stats): raw per-run metrics from run_experiment_suite() and
        their summary from compute_statistics().
    """
    experiment_results = run_experiment_suite()

    banner = "=" * 70
    print(f"\n{banner}\nGENERATING PLOTS\n{banner}")
    plot_final_comparison(experiment_results, "final_comparison.png")

    summary_stats = compute_statistics(experiment_results)
    print_summary(summary_stats)
    return experiment_results, summary_stats
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run everything, then print a closing banner.
    results, stats = main()
    closing_rule = "=" * 70
    print(f"\n{closing_rule}\nEXPERIMENT COMPLETE!\n{closing_rule}")
|
|