| """ |
| Extended Gradient Clipping Experiment V2: Testing Physics-of-AI Predictions |
| |
| Key changes from V1: |
| 1. More epochs (10 instead of 3) to allow rare class learning |
| 2. Smaller learning rate (0.01) for more stable training |
| 3. More frequent tracking to catch dynamics |
| 4. Added loss tracking per class to understand learning dynamics |
| |
| Predictions being tested: |
| - Prediction 2: Representation Collapse (effective dimensionality drops without clipping) |
| - Prediction 4: Rare Sample Learning (clipping improves rare class accuracy) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import random |
| from typing import Dict, List, Tuple |
|
|
# Global RNG seed shared by torch, numpy, and the stdlib `random` module.
SEED = 42
|
|
|
|
def set_seeds(seed=SEED):
    """Seed every RNG used by the experiment (torch, numpy, stdlib random)."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
|
|
|
|
| class SimpleNextTokenModel(nn.Module): |
| def __init__(self, vocab_size=4, embedding_dim=16): |
| super().__init__() |
| self.embedding = nn.Embedding(vocab_size, embedding_dim) |
| self.linear = nn.Linear(embedding_dim, vocab_size) |
| |
| def forward(self, x): |
| embedded = self.embedding(x) |
| logits = self.linear(embedded) |
| return logits |
| |
| def get_embeddings(self): |
| return self.embedding.weight.data.clone() |
|
|
|
|
def compute_effective_dimension(embedding_matrix: torch.Tensor) -> float:
    """Entropy-based effective dimensionality of the row covariance spectrum.

    The eigenvalues of the (sample) covariance of the rows are normalized
    into a distribution; the result is exp(entropy) of that distribution,
    i.e. the "perplexity" of the spectrum. Equal eigenvalues give the full
    ambient dimension; a single dominant direction gives ~1.
    """
    n_rows = embedding_matrix.shape[0]
    deviations = embedding_matrix - embedding_matrix.mean(dim=0, keepdim=True)
    covariance = deviations.T @ deviations / (n_rows - 1)
    # Clamp so that log() is finite for (numerically) zero eigenvalues.
    spectrum = torch.linalg.eigvalsh(covariance).clamp(min=1e-10)
    probs = spectrum / spectrum.sum()
    spectral_entropy = -(probs * probs.log()).sum()
    return float(torch.exp(spectral_entropy))
|
|
|
|
def compute_per_class_accuracy(model: nn.Module, inputs: torch.Tensor,
                               targets: torch.Tensor,
                               n_classes: int = 4) -> Dict[int, float]:
    """Compute accuracy for each target class.

    Generalized: the class count was previously hard-coded to 4; it is now
    the `n_classes` parameter (default 4, so existing callers are unchanged).

    Args:
        model: module mapping `inputs` to per-class logits of shape (N, C).
        inputs: input tensor forwarded through the model.
        targets: class-index tensor of shape (N,).
        n_classes: number of classes to report on.

    Returns:
        Dict mapping class index -> accuracy on that class, or None for
        classes with no samples in `targets`.
    """
    model.eval()
    with torch.no_grad():
        logits = model(inputs)
        predictions = logits.argmax(dim=1)

    accuracies = {}
    for class_idx in range(n_classes):
        mask = targets == class_idx
        if mask.sum() > 0:
            correct = (predictions[mask] == targets[mask]).float().mean().item()
            accuracies[class_idx] = correct
        else:
            # Class absent from targets; None distinguishes "no data" from 0.0.
            accuracies[class_idx] = None

    return accuracies
|
|
|
|
def compute_per_class_loss(model: nn.Module, inputs: torch.Tensor,
                           targets: torch.Tensor, criterion: nn.Module,
                           n_classes: int = 4) -> Dict[int, float]:
    """Compute average loss for each target class.

    Generalized: the class count was previously hard-coded to 4; it is now
    the `n_classes` parameter (default 4), mirroring
    compute_per_class_accuracy so the two helpers stay consistent.

    Args:
        model: module mapping `inputs` to per-class logits of shape (N, C).
        inputs: input tensor forwarded through the model.
        targets: class-index tensor of shape (N,).
        criterion: loss module applied to (logits[mask], targets[mask]).
        n_classes: number of classes to report on.

    Returns:
        Dict mapping class index -> mean loss on that class, or None for
        classes with no samples in `targets`.
    """
    model.eval()
    losses = {}
    with torch.no_grad():
        logits = model(inputs)
        for class_idx in range(n_classes):
            mask = targets == class_idx
            if mask.sum() > 0:
                class_loss = criterion(logits[mask], targets[mask]).item()
                losses[class_idx] = class_loss
            else:
                # Class absent from targets; None distinguishes "no data" from 0.0.
                losses[class_idx] = None
    return losses
|
|
|
|
def create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED):
    """Build an imbalanced dataset: random token inputs, `n_rare` positions
    labeled class 1 ("rare"), everything else class 0 ("common").

    Returns (inputs, targets, sorted rare positions).
    """
    set_seeds(seed)
    inputs = torch.randint(0, 4, (n_samples,))
    targets = torch.zeros(n_samples, dtype=torch.long)
    rare_positions = sorted(random.sample(range(n_samples), n_rare))
    targets[rare_positions] = 1
    return inputs, targets, rare_positions
|
|
|
|
def train_with_tracking(inputs: torch.Tensor, targets: torch.Tensor,
                        rare_indices: List[int], clip_grad: bool = False,
                        max_norm: float = 1.0, n_epochs: int = 10,
                        lr: float = 0.01, init_weights=None,
                        track_every: int = 50) -> Dict:
    """
    Extended training with comprehensive tracking.

    Trains a fresh SimpleNextTokenModel with per-sample SGD (batch size 1,
    no shuffling: samples are visited in dataset order every epoch),
    optionally applying gradient-norm clipping before each optimizer step.

    Args:
        inputs: token ids, shape (N,).
        targets: class labels, shape (N,).
        rare_indices: dataset positions of the rare-class samples, used to
            record the loss each time one of them is encountered.
        clip_grad: if True, clip gradients to `max_norm` before stepping.
        max_norm: clipping threshold (only used when clip_grad is True).
        n_epochs: number of passes over the dataset.
        lr: SGD learning rate.
        init_weights: optional state_dict to load so both runs (with and
            without clipping) start from identical weights.
        track_every: interval (in steps) between effective-dimension and
            per-class accuracy/loss snapshots.

    Returns:
        Dict of metric traces: per-step losses / grad norms / weight norms,
        periodic effective dimensions and per-class accuracies/losses (with
        the step indices they were sampled at), and rare-sample losses.
    """
    # Re-seed so model init and any stochastic ops match across runs.
    set_seeds(SEED)
    model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    if init_weights:
        # Clone so the caller's saved weights are never mutated by training.
        model.load_state_dict({k: v.clone() for k, v in init_weights.items()})

    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    metrics = {
        'losses': [],                 # per-step training loss
        'grad_norms': [],             # per-step pre-clipping gradient norm
        'weight_norms': [],           # per-step total parameter L2 norm
        'effective_dims': [],         # embedding effective dim, sampled periodically
        'effective_dim_steps': [],
        'class_accuracies': {0: [], 1: [], 2: [], 3: []},
        'class_losses': {0: [], 1: [], 2: [], 3: []},
        'accuracy_steps': [],
        'rare_sample_losses': [],     # loss each time a rare sample is visited
        'rare_sample_steps': [],
    }

    mode = "WITH" if clip_grad else "WITHOUT"
    print(f"\n{'='*60}")
    print(f"Training {mode} gradient clipping (max_norm={max_norm})")
    print(f"Learning rate: {lr}, Epochs: {n_epochs}")
    print(f"{'='*60}")

    step = 0
    n_samples = len(inputs)

    for epoch in range(n_epochs):
        model.train()
        epoch_losses = []

        for i in range(n_samples):
            # Batch of one: slice keeps the batch dimension for the model.
            x = inputs[i:i+1]
            y = targets[i:i+1]

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()

            # clip_grad_norm_ with max_norm=inf performs no clipping; it is
            # used here only to read the total (pre-clip) gradient norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf'))

            if clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

            metrics['losses'].append(loss.item())
            metrics['grad_norms'].append(grad_norm.item())

            # Total L2 norm over all parameter tensors.
            total_norm = sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
            metrics['weight_norms'].append(total_norm)

            epoch_losses.append(loss.item())


            # Record the loss whenever a rare-class sample is encountered.
            if i in rare_indices:
                metrics['rare_sample_losses'].append(loss.item())
                metrics['rare_sample_steps'].append(step)


            # Periodic snapshot: effective dimension + per-class metrics.
            if step % track_every == 0:
                emb_matrix = model.get_embeddings()
                eff_dim = compute_effective_dimension(emb_matrix)

                metrics['effective_dims'].append(eff_dim)
                metrics['effective_dim_steps'].append(step)

                class_acc = compute_per_class_accuracy(model, inputs, targets)
                class_loss = compute_per_class_loss(model, inputs, targets, criterion)

                # Classes absent from targets report None; store 0.0 so the
                # traces stay aligned and plottable.
                for cls_idx in range(4):
                    if class_acc[cls_idx] is not None:
                        metrics['class_accuracies'][cls_idx].append(class_acc[cls_idx])
                    else:
                        metrics['class_accuracies'][cls_idx].append(0.0)

                    if class_loss[cls_idx] is not None:
                        metrics['class_losses'][cls_idx].append(class_loss[cls_idx])
                    else:
                        metrics['class_losses'][cls_idx].append(0.0)

                metrics['accuracy_steps'].append(step)

            step += 1

        # End-of-epoch summary over the full dataset.
        avg_loss = np.mean(epoch_losses)
        class_acc = compute_per_class_accuracy(model, inputs, targets)
        class_loss = compute_per_class_loss(model, inputs, targets, criterion)
        eff_dim = compute_effective_dimension(model.get_embeddings())

        # Rare class (class 1) may have no samples; print N/A in that case.
        b_acc = f"{class_acc[1]:.3f}" if class_acc[1] is not None else "N/A"
        b_loss = f"{class_loss[1]:.3f}" if class_loss[1] is not None else "N/A"

        print(f"Epoch {epoch+1:2d}/{n_epochs}: Loss={avg_loss:.4f} | "
              f"Acc A={class_acc[0]:.3f} B={b_acc} | "
              f"Loss A={class_loss[0]:.3f} B={b_loss} | "
              f"EffDim={eff_dim:.3f}")

    return metrics
|
|
|
|
def plot_comprehensive_analysis(metrics_no_clip: Dict, metrics_with_clip: Dict,
                                rare_indices: List[int], filename: str,
                                n_samples: int = 1000):
    """Create comprehensive 8-panel analysis and save it to `filename`.

    Panels (rows top to bottom): effective dimension without/with clipping,
    common vs. rare class accuracy, common vs. rare class loss, and
    gradient / weight norms for both runs overlaid.

    Fixes vs. previous revision:
    - the completion message printed a hard-coded placeholder instead of
      the actual output path;
    - the suptitle hard-coded "10 epochs" and "990 common / 10 rare"
      instead of deriving them from the data (n_epochs was computed but
      never used).

    Args:
        metrics_no_clip: metrics dict from train_with_tracking (no clipping).
        metrics_with_clip: metrics dict from train_with_tracking (clipping).
        rare_indices: dataset positions of the rare-class samples.
        filename: output path for the saved PNG.
        n_samples: dataset size, used to recover the epoch count.
    """
    fig = plt.figure(figsize=(20, 16))
    gs = fig.add_gridspec(4, 2, hspace=0.35, wspace=0.25)

    # One loss entry per step, n_samples steps per epoch.
    n_epochs = len(metrics_no_clip['losses']) // n_samples

    # Row 1: effective dimension trajectories.
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])

    ax1.plot(metrics_no_clip['effective_dim_steps'], metrics_no_clip['effective_dims'],
             'b-', linewidth=2, marker='o', markersize=3)
    ax1.set_ylabel('Effective Dimension', fontsize=11)
    ax1.set_title('Effective Dim - WITHOUT Clipping', fontsize=12, fontweight='bold', color='red')
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim([2.0, 3.5])

    ax2.plot(metrics_with_clip['effective_dim_steps'], metrics_with_clip['effective_dims'],
             'g-', linewidth=2, marker='o', markersize=3)
    ax2.set_title('Effective Dim - WITH Clipping', fontsize=12, fontweight='bold', color='green')
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim([2.0, 3.5])

    # Row 2: per-class accuracy (class 0 = common "A", class 1 = rare "B").
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])

    ax3.plot(metrics_no_clip['accuracy_steps'], metrics_no_clip['class_accuracies'][0],
             'r-', linewidth=2, alpha=0.7, label='Without Clip')
    ax3.plot(metrics_with_clip['accuracy_steps'], metrics_with_clip['class_accuracies'][0],
             'g-', linewidth=2, alpha=0.7, label='With Clip')
    ax3.set_ylabel('Accuracy', fontsize=11)
    ax3.set_title("Common Class 'A' Accuracy", fontsize=12, fontweight='bold')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_ylim([0, 1.05])

    ax4.plot(metrics_no_clip['accuracy_steps'], metrics_no_clip['class_accuracies'][1],
             'r-', linewidth=2, alpha=0.7, label='Without Clip')
    ax4.plot(metrics_with_clip['accuracy_steps'], metrics_with_clip['class_accuracies'][1],
             'g-', linewidth=2, alpha=0.7, label='With Clip')
    ax4.set_title("Rare Class 'B' Accuracy [KEY PREDICTION]", fontsize=12, fontweight='bold', color='purple')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    ax4.set_ylim([0, 1.05])

    # Row 3: per-class loss.
    ax5 = fig.add_subplot(gs[2, 0])
    ax6 = fig.add_subplot(gs[2, 1])

    ax5.plot(metrics_no_clip['accuracy_steps'], metrics_no_clip['class_losses'][0],
             'r-', linewidth=2, alpha=0.7, label='Without Clip')
    ax5.plot(metrics_with_clip['accuracy_steps'], metrics_with_clip['class_losses'][0],
             'g-', linewidth=2, alpha=0.7, label='With Clip')
    ax5.set_ylabel('Loss', fontsize=11)
    ax5.set_title("Common Class 'A' Loss", fontsize=12, fontweight='bold')
    ax5.legend()
    ax5.grid(True, alpha=0.3)

    ax6.plot(metrics_no_clip['accuracy_steps'], metrics_no_clip['class_losses'][1],
             'r-', linewidth=2, alpha=0.7, label='Without Clip')
    ax6.plot(metrics_with_clip['accuracy_steps'], metrics_with_clip['class_losses'][1],
             'g-', linewidth=2, alpha=0.7, label='With Clip')
    ax6.set_title("Rare Class 'B' Loss", fontsize=12, fontweight='bold')
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    # Row 4: gradient norms and weight norms over all steps.
    ax7 = fig.add_subplot(gs[3, 0])
    ax8 = fig.add_subplot(gs[3, 1])

    steps = range(len(metrics_no_clip['grad_norms']))

    ax7.plot(steps, metrics_no_clip['grad_norms'], 'r-', alpha=0.3, linewidth=0.5, label='Without Clip')
    ax7.plot(steps, metrics_with_clip['grad_norms'], 'g-', alpha=0.3, linewidth=0.5, label='With Clip')
    ax7.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Clip threshold')
    ax7.set_ylabel('Gradient Norm', fontsize=11)
    ax7.set_xlabel('Training Step', fontsize=11)
    ax7.set_title('Gradient Norms', fontsize=12, fontweight='bold')
    ax7.legend()
    ax7.grid(True, alpha=0.3)

    ax8.plot(steps, metrics_no_clip['weight_norms'], 'r-', alpha=0.7, linewidth=1, label='Without Clip')
    ax8.plot(steps, metrics_with_clip['weight_norms'], 'g-', alpha=0.7, linewidth=1, label='With Clip')
    ax8.set_xlabel('Training Step', fontsize=11)
    ax8.set_title('Weight Norms', fontsize=12, fontweight='bold')
    ax8.legend()
    ax8.grid(True, alpha=0.3)

    # Derive the subtitle from the actual run configuration (the learning
    # rate is not available to this function, so it stays literal).
    fig.suptitle('Extended Gradient Clipping Analysis: Testing Physics-of-AI Predictions\n'
                 f'({n_epochs} epochs, lr=0.01, '
                 f'{n_samples - len(rare_indices)} common / {len(rare_indices)} rare samples)',
                 fontsize=14, fontweight='bold', y=1.01)

    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    # Bug fix: previously printed a hard-coded placeholder instead of the path.
    print(f"Comprehensive analysis saved to: {filename}")
|
|
|
|
def plot_rare_sample_dynamics(metrics_no_clip: Dict, metrics_with_clip: Dict,
                              filename: str, n_samples: int = 1000,
                              rare_indices: List[int] = None):
    """Plot dynamics specifically at rare sample positions.

    Panels: rare-sample loss over time, its distribution, gradient norms at
    rare-sample steps, and a text summary comparing the two runs.

    Fixes vs. previous revision:
    - the completion message printed a hard-coded placeholder instead of
      the actual output path;
    - the dataset size and rare-sample positions were buried as magic
      values inside the body; they are now keyword parameters whose
      defaults reproduce the old behavior (the default `rare_indices` are
      the positions create_imbalanced_dataset yields for SEED=42).

    Args:
        metrics_no_clip: metrics dict from train_with_tracking (no clipping).
        metrics_with_clip: metrics dict from train_with_tracking (clipping).
        filename: output path for the saved PNG.
        n_samples: dataset size, used to map (epoch, index) -> step.
        rare_indices: dataset positions of the rare-class samples; pass the
            list returned by create_imbalanced_dataset.
    """
    if rare_indices is None:
        # Backward-compatible fallback: positions produced by SEED=42.
        rare_indices = [25, 104, 114, 142, 228, 250, 281, 654, 754, 759]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Panel 1: loss recorded each time a rare sample was visited.
    ax1 = axes[0, 0]
    ax1.plot(metrics_no_clip['rare_sample_steps'], metrics_no_clip['rare_sample_losses'],
             'ro-', alpha=0.7, markersize=3, linewidth=0.5, label='Without Clip')
    ax1.plot(metrics_with_clip['rare_sample_steps'], metrics_with_clip['rare_sample_losses'],
             'go-', alpha=0.7, markersize=3, linewidth=0.5, label='With Clip')
    ax1.set_ylabel('Loss at Rare Sample', fontsize=11)
    ax1.set_title('Loss When Encountering Rare Samples', fontsize=12, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Panel 2: distribution of those losses.
    ax2 = axes[0, 1]
    ax2.hist(metrics_no_clip['rare_sample_losses'], bins=30, alpha=0.5, color='red',
             label=f"Without Clip (mean={np.mean(metrics_no_clip['rare_sample_losses']):.3f})")
    ax2.hist(metrics_with_clip['rare_sample_losses'], bins=30, alpha=0.5, color='green',
             label=f"With Clip (mean={np.mean(metrics_with_clip['rare_sample_losses']):.3f})")
    ax2.set_xlabel('Loss', fontsize=11)
    ax2.set_ylabel('Count', fontsize=11)
    ax2.set_title('Distribution of Rare Sample Losses', fontsize=12, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Panel 3: gradient norms at the steps where rare samples were seen.
    ax3 = axes[1, 0]

    # Training visits samples in order, so a rare sample at dataset index
    # `idx` is seen at global step epoch * n_samples + idx.
    n_epochs = len(metrics_no_clip['losses']) // n_samples

    rare_grad_norms_no = []
    rare_grad_norms_with = []
    rare_steps = []

    for epoch in range(n_epochs):
        for idx in rare_indices:
            step = epoch * n_samples + idx
            if step < len(metrics_no_clip['grad_norms']):
                rare_grad_norms_no.append(metrics_no_clip['grad_norms'][step])
                rare_grad_norms_with.append(metrics_with_clip['grad_norms'][step])
                rare_steps.append(step)

    ax3.scatter(rare_steps, rare_grad_norms_no, c='red', alpha=0.6, s=20, label='Without Clip')
    ax3.scatter(rare_steps, rare_grad_norms_with, c='green', alpha=0.6, s=20, label='With Clip')
    ax3.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Clip threshold')
    ax3.set_xlabel('Training Step', fontsize=11)
    ax3.set_ylabel('Gradient Norm', fontsize=11)
    ax3.set_title('Gradient Norms at Rare Sample Positions', fontsize=12, fontweight='bold')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Panel 4: text summary of the comparison.
    ax4 = axes[1, 1]
    ax4.axis('off')

    mean_rare_loss_no = np.mean(metrics_no_clip['rare_sample_losses'])
    mean_rare_loss_with = np.mean(metrics_with_clip['rare_sample_losses'])
    mean_rare_grad_no = np.mean(rare_grad_norms_no)
    mean_rare_grad_with = np.mean(rare_grad_norms_with)

    # Final snapshot of rare-class accuracy (0 if no snapshots were taken).
    final_acc_b_no = metrics_no_clip['class_accuracies'][1][-1] if metrics_no_clip['class_accuracies'][1] else 0
    final_acc_b_with = metrics_with_clip['class_accuracies'][1][-1] if metrics_with_clip['class_accuracies'][1] else 0

    summary_text = f"""
    RARE SAMPLE DYNAMICS SUMMARY
    ββββββββββββββββββββββββββββββββββββββββββββββββββββ

    At Rare Sample Positions:
    βββββββββββββββββββββββββββββββββββββββββββββββββββββ
    Mean Loss (WITHOUT Clipping):  {mean_rare_loss_no:.4f}
    Mean Loss (WITH Clipping):     {mean_rare_loss_with:.4f}
    Loss Reduction:                {(mean_rare_loss_no - mean_rare_loss_with) / mean_rare_loss_no * 100:+.1f}%

    Mean Gradient Norm (WITHOUT):  {mean_rare_grad_no:.4f}
    Mean Gradient Norm (WITH):     {mean_rare_grad_with:.4f}
    Gradient Reduction:            {(mean_rare_grad_no - mean_rare_grad_with) / mean_rare_grad_no * 100:+.1f}%

    Final Rare Class Accuracy:
    βββββββββββββββββββββββββββββββββββββββββββββββββββββ
    WITHOUT Clipping:  {final_acc_b_no:.1%}
    WITH Clipping:     {final_acc_b_with:.1%}

    ββββββββββββββββββββββββββββββββββββββββββββββββββββ
    PHYSICS-OF-AI INTERPRETATION:

    Gradient clipping acts as a "velocity limiter" in
    weight space, preventing the model from making
    sudden large updates when encountering rare samples.

    This allows the model to gradually learn the rare
    class pattern rather than overshooting and forgetting.
    """

    ax4.text(0.05, 0.5, summary_text, transform=ax4.transAxes,
             fontsize=10, verticalalignment='center', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.9))

    fig.suptitle('Rare Sample Dynamics Analysis\n'
                 '(How the model behaves when encountering rare class B samples)',
                 fontsize=14, fontweight='bold', y=1.01)

    plt.tight_layout()
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    # Bug fix: previously printed a hard-coded placeholder instead of the path.
    print(f"Rare sample dynamics plot saved to: {filename}")
|
|
|
|
def main() -> Dict:
    """Run the full experiment: build the imbalanced dataset, train twice
    (without and with gradient clipping) from identical initial weights,
    generate analysis plots, and print verdicts for the two predictions.

    Returns:
        Dict with both metric traces and the rare-sample indices.
    """
    print("="*70)
    print("EXTENDED GRADIENT CLIPPING EXPERIMENT V2")
    print("Testing Physics-of-AI Predictions with Extended Training")
    print("="*70)


    # 990 common (class 0) samples and 10 rare (class 1) samples.
    inputs, targets, rare_indices = create_imbalanced_dataset(n_samples=1000, n_rare=10, seed=SEED)

    print(f"\nDataset: {len(inputs)} samples ({(targets == 0).sum().item()} common, {(targets == 1).sum().item()} rare)")
    print(f"Rare indices: {rare_indices}")


    # Snapshot one set of initial weights so both runs start identically.
    set_seeds(SEED)
    init_model = SimpleNextTokenModel(vocab_size=4, embedding_dim=16)
    init_weights = {name: param.clone() for name, param in init_model.state_dict().items()}

    init_eff_dim = compute_effective_dimension(init_model.get_embeddings())
    print(f"Initial effective dimension: {init_eff_dim:.3f}")


    n_epochs = 10
    lr = 0.01


    # Run 1: baseline, no gradient clipping.
    metrics_no_clip = train_with_tracking(
        inputs, targets, rare_indices,
        clip_grad=False, n_epochs=n_epochs, lr=lr,
        init_weights=init_weights, track_every=100
    )


    # Run 2: identical setup but with gradient clipping at max_norm=1.0.
    metrics_with_clip = train_with_tracking(
        inputs, targets, rare_indices,
        clip_grad=True, max_norm=1.0, n_epochs=n_epochs, lr=lr,
        init_weights=init_weights, track_every=100
    )


    print("\n" + "="*70)
    print("GENERATING ANALYSIS PLOTS")
    print("="*70)

    plot_comprehensive_analysis(
        metrics_no_clip, metrics_with_clip, rare_indices,
        "extended_analysis_v2.png"
    )

    plot_rare_sample_dynamics(
        metrics_no_clip, metrics_with_clip,
        "rare_sample_dynamics.png"
    )


    print("\n" + "="*70)
    print("FINAL PREDICTION TEST RESULTS")
    print("="*70)


    # Prediction 2: without clipping, the embedding's effective dimension
    # should fluctuate/collapse more (compared via std of the traces).
    dims_no = metrics_no_clip['effective_dims']
    dims_with = metrics_with_clip['effective_dims']

    print("\n[PREDICTION 2] Representation Collapse:")
    print(f"  Effective Dim Variance (WITHOUT): {np.std(dims_no):.6f}")
    print(f"  Effective Dim Variance (WITH):    {np.std(dims_with):.6f}")
    print(f"  Verdict: {'SUPPORTED' if np.std(dims_no) > np.std(dims_with) else 'NOT SUPPORTED'}")


    # Prediction 4: clipping should match or improve rare-class accuracy.
    final_acc_b_no = metrics_no_clip['class_accuracies'][1][-1]
    final_acc_b_with = metrics_with_clip['class_accuracies'][1][-1]

    print("\n[PREDICTION 4] Rare Sample Learning:")
    print(f"  Final Rare Class Accuracy (WITHOUT): {final_acc_b_no:.1%}")
    print(f"  Final Rare Class Accuracy (WITH):    {final_acc_b_with:.1%}")
    print(f"  Verdict: {'SUPPORTED' if final_acc_b_with >= final_acc_b_no else 'NOT SUPPORTED'}")

    return {
        'metrics_no_clip': metrics_no_clip,
        'metrics_with_clip': metrics_with_clip,
        'rare_indices': rare_indices,
    }
|
|
|
|
# Script entry point: run the full experiment and report completion.
if __name__ == "__main__":
    results = main()
    print("\n" + "="*70)
    print("EXPERIMENT COMPLETE!")
    print("="*70)
|
|