| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torchvision import datasets, transforms |
| from torch.utils.data import DataLoader, random_split |
| from torch.utils.tensorboard import SummaryWriter |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import numpy as np |
| import argparse |
| import os |
| import logging |
| from tqdm import tqdm |
| from datetime import datetime |
| import json |
| import random |
| from sklearn.metrics import confusion_matrix, classification_report |
| from pathlib import Path |
|
|
| |
def setup_logging(log_dir):
    """Configure root logging to write to both a log file and the console.

    Creates *log_dir* (and parents) if missing, then installs a file
    handler for ``training.log`` plus a stream handler, and returns a
    module-level logger.
    """
    target = Path(log_dir)
    target.mkdir(parents=True, exist_ok=True)

    file_handler = logging.FileHandler(target / 'training.log')
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger(__name__)
|
|
| |
def set_seed(seed=42):
    """Seed every RNG in play (Python, NumPy, PyTorch CPU and CUDA).

    Also forces deterministic cuDNN kernels, trading autotuning speed
    for run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
| |
class ConvNet(nn.Module):
    """Convolutional Neural Network for MNIST.

    Two conv stages (conv-BN-ReLU x2, max-pool, spatial dropout each)
    followed by a three-layer fully connected classifier head.
    Expects (N, 1, 28, 28) inputs and returns raw class logits.
    """

    def __init__(self, dropout_rate=0.3, num_classes=10):
        super(ConvNet, self).__init__()

        # Stage 1: 1 -> 32 -> 64 channels; 28x28 -> 14x14 after pooling.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Stage 2: 64 -> 128 -> 128 channels; 14x14 -> 7x7 after pooling.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        self.pool = nn.MaxPool2d(2, 2)
        # Lighter dropout on feature maps than on the dense layers.
        self.dropout_conv = nn.Dropout2d(dropout_rate * 0.5)

        # Classifier head over the flattened 128x7x7 feature map.
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(256, 128)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(dropout_rate * 0.5)

        self.fc3 = nn.Linear(128, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-init conv/linear layers; unit-gain, zero-shift BatchNorm."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out',
                                        nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return raw class logits for a batch of 1x28x28 images."""
        relu = torch.relu

        # Stage 1
        x = relu(self.bn1(self.conv1(x)))
        x = relu(self.bn2(self.conv2(x)))
        x = self.dropout_conv(self.pool(x))

        # Stage 2
        x = relu(self.bn3(self.conv3(x)))
        x = relu(self.bn4(self.conv4(x)))
        x = self.dropout_conv(self.pool(x))

        # Flatten, then classify.
        x = x.view(x.size(0), -1)
        x = self.dropout1(relu(self.bn5(self.fc1(x))))
        x = self.dropout2(relu(self.bn6(self.fc2(x))))
        return self.fc3(x)
|
|
| |
| class ImprovedNN(nn.Module): |
| """Enhanced fully connected network with configurable architecture""" |
| def __init__(self, input_size=784, hidden_sizes=[512, 256, 128], |
| num_classes=10, dropout_rate=0.3): |
| super(ImprovedNN, self).__init__() |
| |
| layers = [] |
| prev_size = input_size |
| |
| for i, hidden_size in enumerate(hidden_sizes): |
| layers.extend([ |
| nn.Linear(prev_size, hidden_size), |
| nn.BatchNorm1d(hidden_size), |
| nn.ReLU(), |
| nn.Dropout(dropout_rate if i < len(hidden_sizes) - 1 else dropout_rate * 0.5) |
| ]) |
| prev_size = hidden_size |
| |
| layers.append(nn.Linear(prev_size, num_classes)) |
| self.network = nn.Sequential(*layers) |
| |
| self._initialize_weights() |
| |
| def _initialize_weights(self): |
| for m in self.modules(): |
| if isinstance(m, nn.Linear): |
| nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') |
| if m.bias is not None: |
| nn.init.constant_(m.bias, 0) |
| elif isinstance(m, nn.BatchNorm1d): |
| nn.init.constant_(m.weight, 1) |
| nn.init.constant_(m.bias, 0) |
| |
| def forward(self, x): |
| x = x.view(x.size(0), -1) |
| return self.network(x) |
|
|
| |
class Trainer:
    """Drives the full training loop: AMP, LR warmup, per-epoch scheduling,
    TensorBoard logging, checkpointing, and early stopping.

    The model, loaders, loss, optimizer, and scheduler are supplied by the
    caller; ``train()`` runs the configured number of epochs and returns
    the best validation accuracy observed.
    """

    def __init__(self, model, train_loader, val_loader, test_loader,
                 criterion, optimizer, scheduler, device, args, logger):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.args = args
        self.logger = logger

        # TensorBoard event writer (closed at the end of train()).
        self.writer = SummaryWriter(log_dir=args.log_dir)

        # Per-epoch history and early-stopping bookkeeping.
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.best_val_acc = 0.0
        self.patience_counter = 0

        # Mixed-precision gradient scaler; only meaningful on CUDA.
        # NOTE(review): torch.cuda.amp.GradScaler is deprecated in newer
        # torch in favor of torch.amp.GradScaler -- confirm target version.
        self.scaler = torch.cuda.amp.GradScaler() if args.use_amp and device.type == 'cuda' else None

    def train_epoch(self, epoch):
        """Run one training epoch; return (mean batch loss, accuracy %)."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1} [Train]")

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)

            self.optimizer.zero_grad(set_to_none=True)

            if self.scaler:
                # AMP path: scaled backward, then unscale so gradient
                # clipping operates on true gradient magnitudes.
                with torch.cuda.amp.autocast():
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)

                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Log batch-level metrics every 50 batches.
            global_step = epoch * len(self.train_loader) + batch_idx
            if batch_idx % 50 == 0:
                self.writer.add_scalar('Train/BatchLoss', loss.item(), global_step)
                self.writer.add_scalar('Train/BatchAcc', 100. * correct / total, global_step)

            progress_bar.set_postfix({
                'Loss': f"{loss.item():.4f}",
                'Acc': f"{100.*correct/total:.2f}%"
            })

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def validate(self, loader, phase="Val"):
        """Evaluate on *loader* without gradients.

        Returns (mean batch loss, accuracy %, predictions, labels) where
        the last two are flat numpy arrays over the whole loader.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        all_preds = []
        all_labels = []

        with torch.no_grad():
            progress_bar = tqdm(loader, desc=f"[{phase}]")
            for images, labels in progress_bar:
                images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)

                if self.scaler:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(images)
                        loss = self.criterion(outputs, labels)
                else:
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

                progress_bar.set_postfix({
                    'Loss': f"{loss.item():.4f}",
                    'Acc': f"{100.*correct/total:.2f}%"
                })

        epoch_loss = running_loss / len(loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc, np.array(all_preds), np.array(all_labels)

    def train(self):
        """Run the full schedule; return the best validation accuracy (%).

        Performs linear LR warmup for the first ``warmup_epochs`` epochs,
        then steps the scheduler once per epoch. Saves the best model and
        periodic checkpoints; stops early when validation accuracy fails
        to improve for ``early_stop_patience`` epochs.
        """
        self.logger.info(f"Starting training for {self.args.epochs} epochs")
        self.logger.info(f"Model: {self.args.model_type}, Optimizer: {self.args.optimizer}")
        self.logger.info(f"Learning rate: {self.args.lr}, Batch size: {self.args.batch_size}")

        start_time = datetime.now()

        for epoch in range(self.args.epochs):
            # Linear warmup: ramp LR from lr/warmup_epochs up to lr.
            if epoch < self.args.warmup_epochs:
                warmup_lr = self.args.lr * (epoch + 1) / self.args.warmup_epochs
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = warmup_lr

            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc, val_preds, val_labels = self.validate(self.val_loader, "Val")

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accs.append(train_acc)
            self.val_accs.append(val_acc)

            # Scheduler takes over only after warmup finishes.
            if epoch >= self.args.warmup_epochs:
                self.scheduler.step()

            current_lr = self.optimizer.param_groups[0]['lr']

            # Epoch-level TensorBoard metrics.
            self.writer.add_scalar('Epoch/TrainLoss', train_loss, epoch)
            self.writer.add_scalar('Epoch/ValLoss', val_loss, epoch)
            self.writer.add_scalar('Epoch/TrainAcc', train_acc, epoch)
            self.writer.add_scalar('Epoch/ValAcc', val_acc, epoch)
            self.writer.add_scalar('Epoch/LearningRate', current_lr, epoch)

            per_class_acc = self._compute_per_class_accuracy(val_preds, val_labels)
            for class_idx, acc in enumerate(per_class_acc):
                self.writer.add_scalar(f'PerClass/Val_Class_{class_idx}', acc, epoch)

            self.logger.info(f"Epoch {epoch+1}/{self.args.epochs} | LR: {current_lr:.6f}")
            self.logger.info(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
            self.logger.info(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
            self.logger.info(f"Per-class Val Acc: {[f'{acc:.1f}%' for acc in per_class_acc]}")

            # Track the best model by validation accuracy.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.patience_counter = 0
                self.save_checkpoint(epoch, val_acc, val_loss, train_acc, train_loss, is_best=True)
                self.logger.info(f"✓ New best model saved! Val Acc: {val_acc:.2f}%")
            else:
                self.patience_counter += 1
                self.logger.info(f"No improvement. Patience: {self.patience_counter}/{self.args.early_stop_patience}")

            # Periodic (non-best) checkpoint.
            if (epoch + 1) % self.args.save_freq == 0:
                self.save_checkpoint(epoch, val_acc, val_loss, train_acc, train_loss, is_best=False)

            if self.patience_counter >= self.args.early_stop_patience:
                self.logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

            print("-" * 70)

        training_time = datetime.now() - start_time
        self.logger.info(f"Training complete! Time: {training_time}")
        self.logger.info(f"Best Val Acc: {self.best_val_acc:.2f}%")

        self.save_training_history()

        # BUGFIX: the SummaryWriter was never closed, so buffered events
        # could be dropped and the file handle leaked. Flush and release.
        self.writer.close()

        return self.best_val_acc

    def _compute_per_class_accuracy(self, preds, labels):
        """Return a 10-element list of per-class accuracies (%) for digits
        0-9; classes absent from *labels* report 0.0."""
        per_class_acc = []
        for class_idx in range(10):
            mask = labels == class_idx
            if mask.sum() > 0:
                class_acc = 100. * (preds[mask] == labels[mask]).sum() / mask.sum()
                per_class_acc.append(class_acc)
            else:
                per_class_acc.append(0.0)
        return per_class_acc

    def save_checkpoint(self, epoch, val_acc, val_loss, train_acc, train_loss, is_best=False):
        """Serialize model/optimizer/scheduler state plus metrics to disk.

        Best models overwrite ``best_model.pth``; periodic checkpoints get
        an epoch-numbered filename.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'train_acc': train_acc,
            'train_loss': train_loss,
            'best_val_acc': self.best_val_acc,
            'args': vars(self.args)
        }

        if is_best:
            path = Path(self.args.save_dir) / 'best_model.pth'
        else:
            path = Path(self.args.save_dir) / f'checkpoint_epoch_{epoch+1}.pth'

        torch.save(checkpoint, path)

    def save_training_history(self):
        """Dump the per-epoch loss/accuracy history to a JSON file."""
        history = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accs': self.train_accs,
            'val_accs': self.val_accs,
            'best_val_acc': self.best_val_acc
        }

        path = Path(self.args.save_dir) / 'training_history.json'
        with open(path, 'w') as f:
            json.dump(history, f, indent=4)

        self.logger.info(f"Training history saved to {path}")
|
|
| |
def plot_training_curves(history_path, save_path):
    """Plot loss and accuracy curves from a saved training-history JSON.

    Writes a side-by-side figure (loss left, accuracy right) to *save_path*.
    """
    with open(history_path, 'r') as f:
        history = json.load(f)

    epochs_range = range(1, len(history['train_losses']) + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # (axis, train key, val key, y-label, title, train legend, val legend)
    panels = [
        (ax1, 'train_losses', 'val_losses', 'Loss',
         'Training and Validation Loss', 'Train Loss', 'Val Loss'),
        (ax2, 'train_accs', 'val_accs', 'Accuracy (%)',
         'Training and Validation Accuracy', 'Train Acc', 'Val Acc'),
    ]
    for ax, train_key, val_key, ylabel, title, train_lbl, val_lbl in panels:
        ax.plot(epochs_range, history[train_key], 'b-', label=train_lbl, linewidth=2)
        ax.plot(epochs_range, history[val_key], 'r-', label=val_lbl, linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
|
|
def plot_confusion_matrix(y_true, y_pred, save_path):
    """Render a 10x10 confusion-matrix heatmap and save it to *save_path*."""
    matrix = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=range(10),
        yticklabels=range(10),
    )
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
|
|
def plot_predictions(model, test_loader, device, save_path, num_samples=20):
    """Plot a grid of test images with predicted vs. true labels.

    Correct predictions are titled in green, errors in red, with the
    softmax confidence of the predicted class. Saves the figure to
    *save_path*.
    """
    model.eval()
    images, labels = next(iter(test_loader))
    images, labels = images.to(device), labels.to(device)

    # BUGFIX: the loader's batch may hold fewer than num_samples images,
    # and num_samples may not divide evenly by the row count -- clamp to
    # the batch size and use ceiling division so axes indexing is safe.
    num_samples = min(num_samples, images.size(0))
    rows = min(4, num_samples)
    cols = (num_samples + rows - 1) // rows
    fig, axes = plt.subplots(rows, cols, figsize=(15, 8))
    axes = np.atleast_1d(axes).ravel()

    with torch.no_grad():
        outputs = model(images[:num_samples])
        _, predicted = torch.max(outputs, 1)
        probs = torch.softmax(outputs, dim=1)

    for i in range(num_samples):
        img = images[i].cpu().squeeze().numpy()

        # Undo MNIST normalization (mean 0.1307, std 0.3081) for display.
        img = img * 0.3081 + 0.1307
        img = np.clip(img, 0, 1)

        axes[i].imshow(img, cmap='gray')
        color = 'green' if predicted[i] == labels[i] else 'red'
        confidence = probs[i][predicted[i]].item() * 100
        axes[i].set_title(f"Pred: {predicted[i].item()} ({confidence:.1f}%)\nTrue: {labels[i].item()}",
                          color=color, fontweight='bold', fontsize=9)
        axes[i].axis('off')

    # Blank out any unused grid cells.
    for ax in axes[num_samples:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
|
|
def evaluate_model(model, test_loader, device, logger, save_dir):
    """Evaluate *model* on the test loader and save evaluation artifacts.

    Logs overall accuracy and a per-class classification report, writes the
    report and a confusion-matrix plot under *save_dir*, and returns
    (accuracy_percent, predictions, true_labels) as numpy values.
    """
    model.eval()
    pred_batches = []
    label_batches = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            outputs = model(images.to(device))
            pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
            label_batches.append(labels.numpy())

    all_preds = np.concatenate(pred_batches)
    all_labels = np.concatenate(label_batches)

    # Overall accuracy.
    accuracy = 100. * (all_preds == all_labels).sum() / len(all_labels)
    logger.info(f"Test Accuracy: {accuracy:.2f}%")

    # Per-class precision/recall/F1 report.
    digit_names = [str(digit) for digit in range(10)]
    report = classification_report(all_labels, all_preds, target_names=digit_names)
    logger.info(f"\nClassification Report:\n{report}")

    report_path = Path(save_dir) / 'classification_report.txt'
    with open(report_path, 'w') as f:
        f.write(report)

    cm_path = Path(save_dir) / 'confusion_matrix.png'
    plot_confusion_matrix(all_labels, all_preds, cm_path)
    logger.info(f"Confusion matrix saved to {cm_path}")

    return accuracy, all_preds, all_labels
|
|
def parse_args():
    """Build and parse the command-line interface for this script."""
    parser = argparse.ArgumentParser(description='Enhanced MNIST Classifier with Advanced Features')

    # Options as a data table: (flag, add_argument keyword arguments).
    options = [
        # Architecture
        ('--model-type', dict(type=str, default='cnn', choices=['cnn', 'fc'],
                              help='Model architecture type')),
        ('--dropout-rate', dict(type=float, default=0.3, help='Dropout rate')),
        # Optimization
        ('--epochs', dict(type=int, default=20, help='Number of epochs')),
        ('--batch-size', dict(type=int, default=128, help='Batch size')),
        ('--lr', dict(type=float, default=0.001, help='Initial learning rate')),
        ('--optimizer', dict(type=str, default='adamw',
                             choices=['adam', 'sgd', 'adamw'], help='Optimizer choice')),
        ('--weight-decay', dict(type=float, default=1e-4, help='Weight decay')),
        ('--scheduler', dict(type=str, default='onecycle',
                             choices=['cosine', 'onecycle', 'step'],
                             help='Learning rate scheduler')),
        ('--warmup-epochs', dict(type=int, default=2, help='Number of warmup epochs')),
        # Data
        ('--data-dir', dict(type=str, default='./data', help='Data directory')),
        ('--val-split', dict(type=float, default=0.1, help='Validation split ratio')),
        ('--num-workers', dict(type=int, default=4, help='Number of data loading workers')),
        # Training control
        ('--early-stop-patience', dict(type=int, default=7,
                                       help='Early stopping patience')),
        ('--use-amp', dict(action='store_true', help='Use automatic mixed precision')),
        # Output / bookkeeping
        ('--save-dir', dict(type=str, default='./checkpoints', help='Save directory')),
        ('--log-dir', dict(type=str, default='./runs', help='TensorBoard log directory')),
        ('--save-freq', dict(type=int, default=5, help='Save checkpoint every N epochs')),
        ('--seed', dict(type=int, default=42, help='Random seed')),
        # Hardware
        ('--use-gpu', dict(action='store_true', help='Use GPU if available')),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
|
|
def main():
    """End-to-end entry point: data pipeline, model, training, evaluation,
    and artifact plotting for the MNIST classifier."""
    args = parse_args()

    # Reproducibility across Python/NumPy/PyTorch RNGs.
    set_seed(args.seed)

    # Output directories.
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)
    Path(args.log_dir).mkdir(parents=True, exist_ok=True)

    logger = setup_logging(args.save_dir)
    logger.info(f"Arguments: {vars(args)}")

    # Device selection (CPU unless --use-gpu and CUDA is available).
    device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu')
    logger.info(f"Using device: {device}")
    if device.type == 'cuda':
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

    os.makedirs(args.data_dir, exist_ok=True)

    # Augmentation for training only; plain normalization for evaluation.
    train_transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.1))
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # BUGFIX: the validation split previously shared the *augmented*
    # training dataset, so validation metrics were computed on randomly
    # augmented images. Load two views of the train set (augmented and
    # plain) and split them with the same seeded index permutation.
    train_view = datasets.MNIST(root=args.data_dir, train=True, download=True, transform=train_transform)
    eval_view = datasets.MNIST(root=args.data_dir, train=True, download=False, transform=test_transform)
    test_dataset = datasets.MNIST(root=args.data_dir, train=False, download=True, transform=test_transform)

    val_size = int(len(train_view) * args.val_split)
    train_size = len(train_view) - val_size
    perm = torch.randperm(len(train_view),
                          generator=torch.Generator().manual_seed(args.seed)).tolist()
    train_dataset = torch.utils.data.Subset(train_view, perm[:train_size])
    val_dataset = torch.utils.data.Subset(eval_view, perm[train_size:])

    logger.info(f"Train size: {train_size}, Val size: {val_size}, Test size: {len(test_dataset)}")

    # Shared DataLoader settings; only the train loader shuffles.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=device.type == 'cuda',
        persistent_workers=args.num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Model.
    if args.model_type == 'cnn':
        model = ConvNet(dropout_rate=args.dropout_rate).to(device)
    else:
        model = ImprovedNN(dropout_rate=args.dropout_rate).to(device)

    logger.info(f"Model: {args.model_type}")
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    criterion = nn.CrossEntropyLoss()

    # Optimizer.
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                              weight_decay=args.weight_decay, nesterov=True)

    # Scheduler. Trainer steps it once per EPOCH (after warmup), so every
    # schedule here must be sized in epochs.
    post_warmup_epochs = max(1, args.epochs - args.warmup_epochs)
    if args.scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=post_warmup_epochs)
    elif args.scheduler == 'onecycle':
        # BUGFIX: this was configured with steps_per_epoch=len(train_loader)
        # (per-batch sizing) while being stepped per epoch, so the one-cycle
        # schedule barely advanced over a run.
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=args.lr * 10, total_steps=post_warmup_epochs)
    else:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    trainer = Trainer(model, train_loader, val_loader, test_loader,
                      criterion, optimizer, scheduler, device, args, logger)

    best_val_acc = trainer.train()

    # Reload the best checkpoint for final evaluation.
    best_model_path = Path(args.save_dir) / 'best_model.pth'
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    logger.info(f"Loaded best model from epoch {checkpoint['epoch']+1}")

    # Final held-out evaluation.
    logger.info("\n" + "="*70)
    logger.info("Final Evaluation on Test Set")
    logger.info("="*70)
    test_acc, test_preds, test_labels = evaluate_model(model, test_loader, device, logger, args.save_dir)

    # Training-curve plot from the saved history.
    history_path = Path(args.save_dir) / 'training_history.json'
    curves_path = Path(args.save_dir) / 'training_curves.png'
    plot_training_curves(history_path, curves_path)
    logger.info(f"Training curves saved to {curves_path}")

    # Sample-prediction grid.
    pred_path = Path(args.save_dir) / 'predictions.png'
    plot_predictions(model, test_loader, device, pred_path)
    logger.info(f"Predictions saved to {pred_path}")

    # How to reload the saved model elsewhere.
    logger.info("\n" + "="*70)
    logger.info("Model Loading Instructions:")
    logger.info(f"from improved_mnist_classifier import {model.__class__.__name__}")
    logger.info(f"model = {model.__class__.__name__}().to(device)")
    logger.info(f"checkpoint = torch.load('{best_model_path}')")
    logger.info(f"model.load_state_dict(checkpoint['model_state_dict'])")
    logger.info(f"model.eval()")
    logger.info("="*70)

    logger.info(f"\nTraining complete! Best Val Acc: {best_val_acc:.2f}%, Test Acc: {test_acc:.2f}%")
|
| if __name__ == '__main__': |
| main() |