| """ |
| DKM Training Pipeline |
| |
| Implements the training protocol from Section 4 of the paper: |
| - Start from pre-trained model |
| - Insert DKM layers for weight clustering |
| - Fine-tune with SGD (momentum=0.9, lr=0.008) |
| - No loss function or architecture modifications |
| |
| This demo uses CIFAR-10 with a ResNet model to demonstrate the full pipeline. |
| For ImageNet reproduction, scale up to the full dataset and use 8xV100 GPUs. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import torchvision |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader |
| import time |
| import argparse |
| import json |
| import os |
|
|
| from dkm import DKMCompressor, compress_model |
| from dkm.utils import print_compression_summary, count_unique_weights |
|
|
|
|
def get_cifar10_loaders(batch_size=128, num_workers=2):
    """Build CIFAR-10 train/test DataLoaders.

    The training split uses the standard augmentation recipe (random
    32x32 crop with 4-pixel padding plus horizontal flip); both splits
    are normalized with the usual CIFAR-10 channel statistics.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes per DataLoader.

    Returns:
        Tuple of (train_loader, test_loader).
    """
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    # Shared tail of both pipelines: tensor conversion + normalization.
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]

    train_tf = transforms.Compose(augmentation + to_normalized_tensor)
    test_tf = transforms.Compose(to_normalized_tensor)

    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_tf
    )
    test_set = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=test_tf
    )

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, test_loader
|
|
|
|
def evaluate(model, dataloader, device, criterion=None):
    """Evaluate top-1 classification accuracy and (optionally) average loss.

    Runs the model in eval mode under ``torch.no_grad()``.

    Args:
        model: Module producing class logits of shape (batch, num_classes).
        dataloader: Iterable of (images, labels) batches.
        device: torch.device the batches are moved to.
        criterion: Optional loss callable (e.g. CrossEntropyLoss). When
            given, the mean per-batch loss is also computed.

    Returns:
        Tuple ``(accuracy_percent, avg_loss)``. ``avg_loss`` is 0.0 when no
        criterion is supplied. Both values are 0.0 for an empty dataloader
        (the naive division would raise ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            if criterion is not None:
                total_loss += criterion(outputs, labels).item()
                n_batches += 1

            # Top-1 prediction: index of the max logit per sample.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Guard both divisions against an empty dataloader.
    accuracy = 100.0 * correct / total if total else 0.0
    avg_loss = total_loss / n_batches if n_batches else 0.0
    return accuracy, avg_loss
|
|
|
|
def train_dkm(
    model,
    trainloader,
    testloader,
    device,
    bits=2,
    dim=1,
    tau=2e-5,
    epochs=50,
    lr=0.008,
    momentum=0.9,
    weight_decay=0.0,
    skip_first_last=True,
    conv_config=None,
    fc_config=None,
):
    """
    Fine-tune a pre-trained model with DKM weight clustering.

    Follows the paper's protocol (Section 4):
    - SGD optimizer with momentum and a fixed learning rate
    - Original loss function (CrossEntropyLoss), no architecture changes
    - DKM layers inserted into the forward pass via ``compress_model``

    Args:
        model: Pre-trained nn.Module to compress.
        trainloader: Training DataLoader yielding (images, labels).
        testloader: Evaluation DataLoader yielding (images, labels).
        device: torch.device for training and evaluation.
        bits: Bits per cluster assignment (2**bits centroids).
        dim: Clustering dimension passed through to ``compress_model``.
        tau: Temperature parameter for the soft assignments.
        epochs: Number of fine-tuning epochs.
        lr: Fixed SGD learning rate (no per-layer tuning).
        momentum: SGD momentum.
        weight_decay: SGD weight decay (default: none).
        skip_first_last: If True, leave first/last layers uncompressed.
        conv_config: Optional config overrides for conv layers.
        fc_config: Optional config overrides for fully-connected layers.

    Returns:
        ``(compressor, history, info)``: the trained DKM wrapper, a
        per-epoch list of metric dicts, and the compression summary from
        ``compressor.get_compression_info()``.
    """
    print(f"\n{'='*70}")
    print("DKM Compression Training")
    print(f"{'='*70}")
    print(f"Bits: {bits}, Dim: {dim}, Tau: {tau}")
    print(f"Epochs: {epochs}, LR: {lr}, Momentum: {momentum}")
    print(f"Skip first/last: {skip_first_last}")
    if conv_config:
        print(f"Conv config: {conv_config}")
    if fc_config:
        print(f"FC config: {fc_config}")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Baseline metrics before compression, used for the final drop report.
    print("\nEvaluating baseline (pre-trained) model...")
    baseline_acc, baseline_loss = evaluate(model, testloader, device, criterion)
    print(f"Baseline accuracy: {baseline_acc:.2f}%, Loss: {baseline_loss:.4f}")

    # Wrap the model so DKM clustering runs inside the forward pass.
    compressor = compress_model(
        model=model,
        bits=bits,
        dim=dim,
        tau=tau,
        conv_config=conv_config,
        fc_config=fc_config,
        skip_first_last=skip_first_last,
    )
    compressor = compressor.to(device)

    info = compressor.get_compression_info()
    print_compression_summary(info)

    # Plain SGD over the wrapper's parameters -- per the paper, no
    # optimizer or schedule changes versus ordinary fine-tuning.
    optimizer = optim.SGD(
        compressor.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    best_acc = 0.0
    history = []

    for epoch in range(epochs):
        compressor.train()
        running_loss = 0.0
        correct = 0
        total = 0
        epoch_start = time.time()

        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = compressor(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_acc = 100.0 * correct / total
        train_loss = running_loss / len(trainloader)
        epoch_time = time.time() - epoch_start

        # best_acc tracks the best *test* accuracy across epochs.
        test_acc, test_loss = evaluate(compressor, testloader, device, criterion)
        if test_acc > best_acc:
            best_acc = test_acc

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
            "best_acc": best_acc,
            "epoch_time": epoch_time,
        })

        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
            f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}% | "
            f"Best: {best_acc:.2f}% | Time: {epoch_time:.1f}s"
        )

    # Replace soft assignments with hard nearest-centroid weights.
    print("\nSnapping weights to nearest centroids...")
    compressor.snap_weights()

    # NOTE(review): counts uniques on `model`, which assumes snap_weights()
    # updates the wrapped model's tensors in place -- confirm with the DKM API.
    unique_counts = count_unique_weights(model)
    print("\nUnique weight values per layer after compression:")
    for name, count in unique_counts.items():
        print(f"  {name}: {count} unique values")

    final_acc, final_loss = evaluate(compressor, testloader, device, criterion)
    print(f"\nFinal (snapped) accuracy: {final_acc:.2f}%")
    print(f"Accuracy drop from baseline: {baseline_acc - final_acc:.2f}%")
    # Fixed message: best_acc is measured on the test set, not training data.
    print(f"Best test accuracy: {best_acc:.2f}%")

    return compressor, history, info
|
|
|
|
def main():
    """CLI entry point: parse args, build the demo setup, run DKM training."""
    cli = argparse.ArgumentParser(description="DKM Compression Training")
    cli.add_argument("--bits", type=int, default=2, help="Number of bits for clustering")
    cli.add_argument("--dim", type=int, default=1, help="Clustering dimension")
    cli.add_argument("--tau", type=float, default=2e-5, help="Temperature parameter")
    cli.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    cli.add_argument("--lr", type=float, default=0.008, help="Learning rate")
    cli.add_argument("--batch-size", type=int, default=128, help="Batch size")
    cli.add_argument("--device", type=str, default="auto", help="Device (auto/cpu/cuda)")
    cli.add_argument("--save-path", type=str, default="dkm_compressed.pt")
    args = cli.parse_args()

    # Resolve the compute device; "auto" prefers CUDA when present.
    if args.device != "auto":
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading pre-trained ResNet18...")
    model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
    # Swap the ImageNet head for a 10-way CIFAR-10 classifier.
    model.fc = nn.Linear(model.fc.in_features, 10)

    trainloader, testloader = get_cifar10_loaders(batch_size=args.batch_size)

    compressor, history, info = train_dkm(
        model=model,
        trainloader=trainloader,
        testloader=testloader,
        device=device,
        bits=args.bits,
        dim=args.dim,
        tau=args.tau,
        epochs=args.epochs,
        lr=args.lr,
    )

    # Persist the compressed weights for later deployment.
    torch.save(compressor.export_compressed(), args.save_path)
    print(f"\nCompressed model saved to {args.save_path}")

    # Dump the per-epoch metrics for plotting / analysis.
    with open("training_history.json", "w") as f:
        json.dump(history, f, indent=2)
    print("Training history saved to training_history.json")
|
|
|
|
# Standard script guard: run the CLI entry point only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|