"""
DKM Training Pipeline

Implements the training protocol from Section 4 of the paper:
- Start from pre-trained model
- Insert DKM layers for weight clustering
- Fine-tune with SGD (momentum=0.9, lr=0.008)
- No loss function or architecture modifications

This demo uses CIFAR-10 with a ResNet model to demonstrate the full pipeline.
For ImageNet reproduction, scale up to the full dataset and use 8xV100 GPUs.
"""

import argparse
import json
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dkm import DKMCompressor, compress_model
from dkm.utils import print_compression_summary, count_unique_weights

# Per-channel normalization statistics applied to both CIFAR-10 splits.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def get_cifar10_loaders(batch_size=128, num_workers=2, root="./data", pin_memory=False):
    """Build CIFAR-10 train/test data loaders.

    The training split uses standard augmentation (random 32x32 crop with
    4-pixel padding, random horizontal flip). The test split is only converted
    to a tensor and normalized. Both splits are downloaded into ``root`` if
    they are not already present.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes used by each ``DataLoader``.
        root: Directory where CIFAR-10 is stored / downloaded.
        pin_memory: Forwarded to ``DataLoader``; enabling it can speed up
            host-to-GPU copies when training on CUDA.

    Returns:
        ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train
    )
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test
    )
    testloader = DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return trainloader, testloader


def evaluate(model, dataloader, device, criterion=None):
    """Compute top-1 accuracy and, optionally, the mean loss of ``model``.

    This switches ``model`` to eval mode and leaves it there. Callers that
    continue training must call ``model.train()`` again.

    Args:
        model: Module to evaluate; called as ``model(images)``.
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        criterion: Optional loss callable ``criterion(outputs, labels)``.

    Returns:
        ``(accuracy, avg_loss)``. ``accuracy`` is a percentage. ``avg_loss`` is
        the mean per-batch loss, or 0.0 when no criterion is given. An empty
        dataloader yields ``(0.0, 0.0)`` instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # Explicit None check: the criterion is a module/callable, and
            # its truthiness is not what we want to test here.
            if criterion is not None:
                total_loss += criterion(outputs, labels).item()
                n_batches += 1

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100.0 * correct / total if total else 0.0
    avg_loss = total_loss / max(n_batches, 1)
    return accuracy, avg_loss


def _train_one_epoch(compressor, trainloader, optimizer, criterion, device):
    """Run one optimization pass over ``trainloader``; return (mean loss, accuracy %)."""
    compressor.train()
    running_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0

    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = compressor(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        n_batches += 1

    # Guards keep an empty loader from raising ZeroDivisionError.
    train_loss = running_loss / max(n_batches, 1)
    train_acc = 100.0 * correct / total if total else 0.0
    return train_loss, train_acc


def train_dkm(
    model,
    trainloader,
    testloader,
    device,
    bits=2,
    dim=1,
    tau=2e-5,
    epochs=50,
    lr=0.008,
    momentum=0.9,
    weight_decay=0.0,
    skip_first_last=True,
    conv_config=None,
    fc_config=None,
):
    """Fine-tune ``model`` with DKM weight clustering inserted.

    Follows the paper's protocol:
    - SGD optimizer with momentum=0.9
    - Fixed learning rate (no per-layer tuning)
    - Original loss function (CrossEntropyLoss)
    - DKM inserted into the forward pass via ``compress_model``

    After training, weights are snapped to their nearest centroids and the
    snapped model is evaluated once more.

    Args:
        model: Model to compress; moved to ``device``.
        trainloader: Training ``(images, labels)`` batches.
        testloader: Evaluation ``(images, labels)`` batches.
        device: Torch device used for training and evaluation.
        bits, dim, tau: Forwarded to ``compress_model``.
        epochs: Number of fine-tuning epochs.
        lr, momentum, weight_decay: SGD hyperparameters.
        skip_first_last, conv_config, fc_config: Forwarded to
            ``compress_model``.

    Returns:
        ``(compressor, history, info)``:
        - ``compressor``: the trained DKM wrapper.
        - ``history``: one dict per epoch with keys ``epoch``,
          ``train_loss``, ``train_acc``, ``test_loss``, ``test_acc``,
          ``best_acc`` and ``epoch_time``.
        - ``info``: whatever ``compressor.get_compression_info()`` returned.
    """
    print(f"\n{'=' * 70}")
    print("DKM Compression Training")
    print(f"{'=' * 70}")
    print(f"Bits: {bits}, Dim: {dim}, Tau: {tau}")
    print(f"Epochs: {epochs}, LR: {lr}, Momentum: {momentum}")
    print(f"Skip first/last: {skip_first_last}")
    if conv_config:
        print(f"Conv config: {conv_config}")
    if fc_config:
        print(f"FC config: {fc_config}")

    # Evaluate the uncompressed model first so the final drop can be reported.
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    print("\nEvaluating baseline (pre-trained) model...")
    baseline_acc, baseline_loss = evaluate(model, testloader, device, criterion)
    print(f"Baseline accuracy: {baseline_acc:.2f}%, Loss: {baseline_loss:.4f}")

    # Wrap the model with DKM compression.
    compressor = compress_model(
        model=model,
        bits=bits,
        dim=dim,
        tau=tau,
        conv_config=conv_config,
        fc_config=fc_config,
        skip_first_last=skip_first_last,
    )
    compressor = compressor.to(device)

    info = compressor.get_compression_info()
    print_compression_summary(info)

    # SGD optimizer (paper: momentum=0.9, lr=0.008).
    optimizer = optim.SGD(
        compressor.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    best_acc = 0.0  # best *test* accuracy seen so far
    history = []

    for epoch in range(epochs):
        epoch_start = time.perf_counter()
        train_loss, train_acc = _train_one_epoch(
            compressor, trainloader, optimizer, criterion, device
        )
        epoch_time = time.perf_counter() - epoch_start

        # evaluate() puts the compressor in eval mode. Presumably DKM uses hard
        # centroid assignment there; NOTE(review): confirm in the dkm package.
        test_acc, test_loss = evaluate(compressor, testloader, device, criterion)
        best_acc = max(best_acc, test_acc)

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
            "best_acc": best_acc,
            "epoch_time": epoch_time,
        })

        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
            f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}% | "
            f"Best: {best_acc:.2f}% | Time: {epoch_time:.1f}s"
        )

    # Final step: replace each clustered weight with its nearest centroid.
    print("\nSnapping weights to nearest centroids...")
    compressor.snap_weights()

    # Counting on ``model`` assumes the compressor mutated it in place.
    # NOTE(review): confirm compress_model wraps rather than copies the model.
    unique_counts = count_unique_weights(model)
    print("\nUnique weight values per layer after compression:")
    for name, count in unique_counts.items():
        print(f"  {name}: {count} unique values")

    final_acc, final_loss = evaluate(compressor, testloader, device, criterion)
    print(f"\nFinal (snapped) accuracy: {final_acc:.2f}%")
    print(f"Accuracy drop from baseline: {baseline_acc - final_acc:.2f}%")
    # best_acc is tracked from test_acc, so label it as test accuracy.
    print(f"Best test accuracy during training: {best_acc:.2f}%")

    return compressor, history, info


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def main():
    """Command-line entry point: DKM-compress a ResNet18 on CIFAR-10 and save it."""
    parser = argparse.ArgumentParser(description="DKM Compression Training")
    parser.add_argument("--bits", type=int, default=2, help="Number of bits for clustering")
    parser.add_argument("--dim", type=int, default=1, help="Clustering dimension")
    parser.add_argument("--tau", type=float, default=2e-5, help="Temperature parameter")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=0.008, help="Learning rate")
    parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum")
    parser.add_argument("--weight-decay", type=float, default=0.0, help="SGD weight decay")
    parser.add_argument("--batch-size", type=int, default=128, help="Batch size")
    parser.add_argument("--num-workers", type=int, default=2, help="DataLoader worker processes")
    parser.add_argument("--data-root", type=str, default="./data", help="CIFAR-10 download directory")
    parser.add_argument("--device", type=str, default="auto", help="Device (auto/cpu/cuda)")
    parser.add_argument("--seed", type=int, default=None, help="Random seed (default: unseeded)")
    parser.add_argument("--save-path", type=str, default="dkm_compressed.pt")
    parser.add_argument(
        "--history-path",
        type=str,
        default="training_history.json",
        help="Where to write per-epoch metrics as JSON",
    )
    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    print("Loading pre-trained ResNet18...")
    model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
    # Adapt for CIFAR-10 (10 classes instead of 1000). This head is a newly
    # created layer, not pre-trained, so the "baseline" accuracy reported by
    # train_dkm reflects an untrained classifier head. The paper's protocol
    # starts from a model already trained on the target task.
    model.fc = nn.Linear(model.fc.in_features, 10)

    trainloader, testloader = get_cifar10_loaders(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        root=args.data_root,
        pin_memory=(device.type == "cuda"),
    )

    compressor, history, _info = train_dkm(
        model=model,
        trainloader=trainloader,
        testloader=testloader,
        device=device,
        bits=args.bits,
        dim=args.dim,
        tau=args.tau,
        epochs=args.epochs,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    export = compressor.export_compressed()
    _ensure_parent_dir(args.save_path)
    torch.save(export, args.save_path)
    print(f"\nCompressed model saved to {args.save_path}")

    _ensure_parent_dir(args.history_path)
    with open(args.history_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    print(f"Training history saved to {args.history_path}")


if __name__ == "__main__":
    main()