syedmohaiminulhoque committed · verified
Commit 5e08cea · 1 Parent(s): 8a77fae

Add training script

Files changed (1)
  1. train.py +278 -0
train.py ADDED
"""
DKM Training Pipeline

Implements the training protocol from Section 4 of the paper:
- Start from a pre-trained model
- Insert DKM layers for weight clustering
- Fine-tune with SGD (momentum=0.9, lr=0.008)
- No loss function or architecture modifications

This demo uses CIFAR-10 with a ResNet model to demonstrate the full pipeline.
For ImageNet reproduction, scale up to the full dataset and use 8x V100 GPUs.
"""
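
# A minimal usage sketch (all flags are defined in main() below; the values
# shown are just the defaults):
#
#   python train.py --bits 2 --dim 1 --tau 2e-5 --epochs 50 --lr 0.008 \
#       --batch-size 128 --device auto --save-path dkm_compressed.pt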

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import argparse
import json

from dkm import DKMCompressor, compress_model
from dkm.utils import print_compression_summary, count_unique_weights
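

# For intuition only: at the core of DKM, each weight is softly assigned to
# k = 2**bits learned centroids via a temperature-scaled softmax over negative
# distances, which keeps the clustering differentiable during fine-tuning.
# The sketch below is an illustration under assumed details (Euclidean
# distance, dim=1); the real logic lives inside dkm.DKMCompressor and this
# helper is not used by the pipeline.
def _soft_assign_sketch(weights, centroids, tau):
    """Toy DKM-style soft assignment; illustrative only, unused below."""
    dists = torch.cdist(weights, centroids)    # (n, k) pairwise distances
    attn = torch.softmax(-dists / tau, dim=1)  # soft cluster assignments
    return attn @ centroids                    # (n, 1) soft-clustered weights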


def get_cifar10_loaders(batch_size=128, num_workers=2):
    """Get CIFAR-10 train/test data loaders with standard augmentation."""
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )
    testloader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return trainloader, testloader


def evaluate(model, dataloader, device, criterion=None):
    """Evaluate model accuracy and, optionally, average loss."""
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            if criterion:
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                n_batches += 1

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100.0 * correct / total
    avg_loss = total_loss / max(n_batches, 1)
    return accuracy, avg_loss


def train_dkm(
    model,
    trainloader,
    testloader,
    device,
    bits=2,
    dim=1,
    tau=2e-5,
    epochs=50,
    lr=0.008,
    momentum=0.9,
    weight_decay=0.0,
    skip_first_last=True,
    conv_config=None,
    fc_config=None,
):
    """
    Train a model with DKM compression.

    Follows the paper's protocol:
    - SGD optimizer with momentum=0.9
    - Fixed learning rate (no per-layer tuning)
    - Original loss function (CrossEntropyLoss)
    - DKM inserted into the forward pass
    """
    print(f"\n{'='*70}")
    print("DKM Compression Training")
    print(f"{'='*70}")
    print(f"Bits: {bits}, Dim: {dim}, Tau: {tau}")
    print(f"Epochs: {epochs}, LR: {lr}, Momentum: {momentum}")
    print(f"Skip first/last: {skip_first_last}")
    if conv_config:
        print(f"Conv config: {conv_config}")
    if fc_config:
        print(f"FC config: {fc_config}")

    # Evaluate baseline (pre-trained) accuracy first
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    print("\nEvaluating baseline (pre-trained) model...")
    baseline_acc, baseline_loss = evaluate(model, testloader, device, criterion)
    print(f"Baseline accuracy: {baseline_acc:.2f}%, Loss: {baseline_loss:.4f}")

    # Wrap the model with DKM compression layers
    compressor = compress_model(
        model=model,
        bits=bits,
        dim=dim,
        tau=tau,
        conv_config=conv_config,
        fc_config=fc_config,
        skip_first_last=skip_first_last,
    )
    compressor = compressor.to(device)

    # Print compression info
    info = compressor.get_compression_info()
    print_compression_summary(info)

    # SGD optimizer (paper: momentum=0.9, lr=0.008)
    optimizer = optim.SGD(
        compressor.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    # Training loop
    best_acc = 0.0
    history = []

    for epoch in range(epochs):
        compressor.train()
        running_loss = 0.0
        correct = 0
        total = 0
        epoch_start = time.time()

        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = compressor(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_acc = 100.0 * correct / total
        train_loss = running_loss / len(trainloader)
        epoch_time = time.time() - epoch_start

        # Evaluate (hard assignment gives the accurate eval numbers)
        test_acc, test_loss = evaluate(compressor, testloader, device, criterion)

        if test_acc > best_acc:
            best_acc = test_acc

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
            "best_acc": best_acc,
            "epoch_time": epoch_time,
        })

        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
            f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}% | "
            f"Best: {best_acc:.2f}% | Time: {epoch_time:.1f}s"
        )

    # Final step: snap weights to their nearest centroids
    print("\nSnapping weights to nearest centroids...")
    compressor.snap_weights()

    # Verify unique weights
    unique_counts = count_unique_weights(model)
    print("\nUnique weight values per layer after compression:")
    for name, count in unique_counts.items():
        print(f"  {name}: {count} unique values")

    final_acc, _ = evaluate(compressor, testloader, device, criterion)
    print(f"\nFinal (snapped) accuracy: {final_acc:.2f}%")
    print(f"Accuracy drop from baseline: {baseline_acc - final_acc:.2f}%")
    print(f"Best test accuracy during training: {best_acc:.2f}%")

    return compressor, history, info
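
# Rough storage arithmetic for the defaults above (illustrative): with bits=2,
# each clustered weight stores a 2-bit index into 2**2 = 4 shared centroids
# instead of a 32-bit float, i.e. roughly a 16x reduction on the clustered
# layers, ignoring the small centroid-table and metadata overhead.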


def main():
    parser = argparse.ArgumentParser(description="DKM Compression Training")
    parser.add_argument("--bits", type=int, default=2, help="Number of bits for clustering")
    parser.add_argument("--dim", type=int, default=1, help="Clustering dimension")
    parser.add_argument("--tau", type=float, default=2e-5, help="Temperature parameter")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=0.008, help="Learning rate")
    parser.add_argument("--batch-size", type=int, default=128, help="Batch size")
    parser.add_argument("--device", type=str, default="auto", help="Device (auto/cpu/cuda)")
    parser.add_argument("--save-path", type=str, default="dkm_compressed.pt")
    args = parser.parse_args()

    # Select device
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load a pre-trained ResNet18
    print("Loading pre-trained ResNet18...")
    model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
    # Adapt for CIFAR-10 (10 classes instead of 1000)
    model.fc = nn.Linear(model.fc.in_features, 10)

    # Get data
    trainloader, testloader = get_cifar10_loaders(batch_size=args.batch_size)

    # Train with DKM
    compressor, history, info = train_dkm(
        model=model,
        trainloader=trainloader,
        testloader=testloader,
        device=device,
        bits=args.bits,
        dim=args.dim,
        tau=args.tau,
        epochs=args.epochs,
        lr=args.lr,
    )

    # Save the compressed model
    export = compressor.export_compressed()
    torch.save(export, args.save_path)
    print(f"\nCompressed model saved to {args.save_path}")

    # Save history
    with open("training_history.json", "w") as f:
        json.dump(history, f, indent=2)
    print("Training history saved to training_history.json")


if __name__ == "__main__":
    main()
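
# Reloading the saved artifact later is a plain torch.load; the structure of
# the exported dict is defined by DKMCompressor.export_compressed in the dkm
# package (sketch only):
#
#   export = torch.load("dkm_compressed.pt")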