| """ |
| Inference script for making predictions with trained MNIST models |
| Usage: python inference.py --model-path checkpoints/best_model.pth --image-path my_digit.png |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| from PIL import Image |
| import argparse |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from pathlib import Path |
|
|
| |
class ConvNet(nn.Module):
    """CNN digit classifier: two convolutional stages plus a dense head.

    Each stage applies (conv -> batch norm -> ReLU) twice, then 2x2 max
    pooling and channel dropout. All convolutions are 3x3 with padding 1, so
    only the two pooling steps shrink the feature map; the first dense layer
    is sized for 128 channels of 7x7, i.e. a single-channel 28x28 input.

    Args:
        dropout_rate: Dropout probability after the first dense layer. Half
            of this value is used for the channel dropout and for the
            dropout after the second dense layer.
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, dropout_rate=0.3, num_classes=10):
        super().__init__()
        reduced_rate = dropout_rate * 0.5

        # Stage 1: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Stage 2: 64 -> 128 -> 128 channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Pooling and channel dropout are stateless, so both stages share them.
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout_conv = nn.Dropout2d(reduced_rate)

        # Dense head: 6272 -> 256 -> 128 -> num_classes.
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(256, 128)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(reduced_rate)

        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class logits, one row per input sample."""
        conv_stages = (
            ((self.conv1, self.bn1), (self.conv2, self.bn2)),
            ((self.conv3, self.bn3), (self.conv4, self.bn4)),
        )
        for stage in conv_stages:
            for conv, norm in stage:
                x = torch.relu(norm(conv(x)))
            x = self.dropout_conv(self.pool(x))

        # Flatten everything except the batch axis for the dense head.
        x = x.view(x.size(0), -1)

        dense_blocks = (
            (self.fc1, self.bn5, self.dropout1),
            (self.fc2, self.bn6, self.dropout2),
        )
        for dense, norm, drop in dense_blocks:
            x = drop(torch.relu(norm(dense(x))))

        return self.fc3(x)
|
|
class ImprovedNN(nn.Module):
    """Fully connected classifier built from repeated dense blocks.

    Every hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout, followed
    by a final Linear layer that produces the class logits. The whole stack
    is stored as ``self.network`` (an ``nn.Sequential``).

    Args:
        input_size: Number of features per sample after flattening.
        hidden_sizes: Width of each hidden layer, in order. Any iterable of
            ints is accepted; an empty one yields a single Linear layer.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability for every hidden block except the
            last one, which uses half of this value.
    """

    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128),
                 num_classes=10, dropout_rate=0.3):
        super().__init__()

        # An immutable default avoids the shared-mutable-default pitfall, and
        # materialising once lets callers pass any iterable (even a generator).
        hidden_sizes = list(hidden_sizes)
        last_index = len(hidden_sizes) - 1

        layers = []
        prev_size = input_size
        for i, hidden_size in enumerate(hidden_sizes):
            # The last hidden block gets lighter dropout than the earlier ones.
            block_dropout = dropout_rate if i < last_index else dropout_rate * 0.5
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(block_dropout),
            ])
            prev_size = hidden_size

        layers.append(nn.Linear(prev_size, num_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample to a vector and return its class logits."""
        x = x.view(x.size(0), -1)
        return self.network(x)
|
|
def load_model(model_path, model_type='cnn', device='cpu'):
    """Load a trained model from a checkpoint file and switch it to eval mode.

    The checkpoint must be a dict with a 'model_state_dict' entry. Optional
    entries are 'args' (a dict or ``argparse.Namespace``; its 'model_type'
    value, when present, overrides the ``model_type`` argument) and
    'epoch' / 'val_acc', which are only reported on stdout.

    Args:
        model_path: Path of the checkpoint file.
        model_type: 'cnn' builds a ConvNet; any other value builds an
            ImprovedNN. Used only when the checkpoint does not record its
            own model type.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model with weights loaded, on ``device``, in eval mode.

    Raises:
        KeyError: If the checkpoint has no 'model_state_dict' entry.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    # NOTE(review): the default of torch.load's `weights_only` argument
    # differs between PyTorch releases - confirm that checkpoints holding
    # non-tensor objects (e.g. 'args') load with the version in use.
    checkpoint = torch.load(model_path, map_location=device)

    # Prefer the architecture recorded at training time over the CLI default.
    saved_args = checkpoint.get('args')
    if isinstance(saved_args, argparse.Namespace):
        # A Namespace supports `in` but not subscripting; use its dict view.
        saved_args = vars(saved_args)
    if saved_args is not None and 'model_type' in saved_args:
        model_type = saved_args['model_type']

    # Both architectures are built with their default hyper-parameters, so
    # the checkpoint must have been trained with those defaults as well.
    if model_type == 'cnn':
        model = ConvNet()
    else:
        model = ImprovedNN()

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # 'val_acc' may be absent or non-numeric; applying a float format spec to
    # the old 'unknown' placeholder raised ValueError, so format defensively.
    try:
        val_acc_text = f"{checkpoint.get('val_acc'):.2f}%"
    except (TypeError, ValueError):
        val_acc_text = 'unknown'

    print(f"✓ Loaded {model_type.upper()} model from {model_path}")
    print(f" - Trained for {checkpoint.get('epoch', 'unknown')} epochs")
    print(f" - Validation accuracy: {val_acc_text}")

    return model
|
|
def preprocess_image(image_path):
    """Load an image file and prepare it for inference.

    The image is converted to single-channel grayscale, resized to 28x28
    with Lanczos resampling, converted to a tensor and normalised with
    mean 0.1307 and std 0.3081.

    Args:
        image_path: Path of the image file to read.

    Returns:
        tuple: ``(img_tensor, img_array)`` where ``img_tensor`` is the
        normalised tensor to feed the model and ``img_array`` is a NumPy
        array of the resized grayscale image before normalisation (used for
        display).
    """
    # Open inside a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory copy.
    with Image.open(image_path) as source:
        img = source.convert('L')

    img = img.resize((28, 28), Image.Resampling.LANCZOS)

    # NOTE(review): no polarity inversion is applied. If the model was trained
    # on standard MNIST (light digits on a dark background), dark-on-light
    # inputs may need inverting first - confirm against the training data.

    # NOTE(review): 0.1307 / 0.3081 are presumably the mean/std used at
    # training time - confirm against the training script.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    img_tensor = transform(img)

    # Keep the un-normalised pixels so callers can show what the model saw.
    img_array = np.array(img)

    return img_tensor, img_array
|
|
def predict(model, image_tensor, device):
    """Classify a single image and report the full class distribution.

    Args:
        model: Callable mapping a batched tensor to per-class logits.
        image_tensor: One un-batched image; a leading batch axis is added here.
        device: Device the input is moved to before the forward pass.

    Returns:
        tuple: ``(predicted_class, confidence, probabilities)`` - the arg-max
        class as an int, its softmax probability as a float, and a NumPy
        array holding the softmax probability of every class.
    """
    batch = image_tensor.unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        class_probs = model(batch).softmax(dim=1)
        top_prob, top_class = class_probs.max(dim=1)

    return top_class.item(), top_prob.item(), class_probs.squeeze().cpu().numpy()
|
|
def visualize_prediction(image, predicted_digit, confidence, probabilities, save_path=None):
    """Show the input image next to a bar chart of the class probabilities.

    Args:
        image: 2-D array-like drawn with a gray colormap in the left panel.
        predicted_digit: Index of the predicted class; its bar is drawn green.
        confidence: Probability of the predicted class, in the range 0-1.
        probabilities: NumPy array with one probability (0-1) per class. The
            number of bars follows its length, so models built with a
            different ``num_classes`` are plotted correctly.
        save_path: If truthy, the figure is also written to this path.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the image that was classified.
    ax1.imshow(image, cmap='gray')
    ax1.set_title(f'Input Image\nPredicted: {predicted_digit} ({confidence*100:.1f}%)',
                  fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Right panel: one bar per class, with the predicted class highlighted.
    # The class count comes from the data instead of being fixed at 10.
    digits = np.arange(len(probabilities))
    colors = ['green' if i == predicted_digit else 'gray' for i in digits]
    bars = ax2.bar(digits, probabilities * 100, color=colors, alpha=0.7)

    # Print the percentage just above each bar.
    for bar, prob in zip(bars, probabilities):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height,
                 f'{prob*100:.1f}%',
                 ha='center', va='bottom', fontsize=9)

    ax2.set_xlabel('Digit', fontsize=12)
    ax2.set_ylabel('Confidence (%)', fontsize=12)
    ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
    ax2.set_xticks(digits)
    # Headroom above 100% keeps the label of a full-height bar visible.
    ax2.set_ylim([0, 105])
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ Visualization saved to {save_path}")

    plt.show()
|
|
def predict_batch(model, image_paths, device):
    """Classify several image files one after another, printing each result.

    Args:
        model: Model passed through to ``predict``.
        image_paths: Iterable of image file paths.
        device: Device passed through to ``predict``.

    Returns:
        list[dict]: One entry per image, in input order, with the keys
        'image_path', 'predicted', 'confidence' and 'probabilities'.
    """
    results = []

    for path in image_paths:
        print(f"\nProcessing: {path}")

        # Only the tensor is needed here; the display array is discarded.
        tensor, _ = preprocess_image(path)
        label, score, class_probs = predict(model, tensor, device)

        results.append(dict(
            image_path=path,
            predicted=label,
            confidence=score,
            probabilities=class_probs,
        ))

        print(f" Prediction: {label} (Confidence: {score*100:.2f}%)")

        # Three most probable classes, most likely first.
        ranked = np.argsort(class_probs)[::-1][:3]
        summary = "".join(f"{idx}({class_probs[idx]*100:.1f}%) " for idx in ranked)
        print(" Top 3: " + summary)

    return results
|
|
def _collect_image_paths(image_dir):
    """Return the .png/.jpg/.jpeg files directly inside image_dir, sorted."""
    image_dir = Path(image_dir)
    if not image_dir.is_dir():
        return []
    suffixes = ('.png', '.jpg', '.jpeg')
    # Sorting makes the processing order reproducible (directory listing
    # order is filesystem-dependent); lower-casing the suffix also picks up
    # files such as 'DIGIT.PNG' on case-sensitive filesystems.
    return sorted(
        entry for entry in image_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in suffixes
    )


def main():
    """Command-line entry point: load a checkpoint and classify image(s).

    With ``--image-path`` a single image is classified, all class
    probabilities are printed and a visualization is shown and saved (to
    ``--save-viz`` or 'prediction_visualization.png'). Otherwise, with
    ``--image-dir``, every image in that directory is classified and a
    summary is printed. If both are given, ``--image-path`` takes precedence.
    """
    parser = argparse.ArgumentParser(description='MNIST Digit Recognition Inference')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to trained model checkpoint')
    parser.add_argument('--image-path', type=str,
                        help='Path to input image (28x28 recommended, grayscale)')
    parser.add_argument('--image-dir', type=str,
                        help='Directory containing multiple images to predict')
    parser.add_argument('--model-type', type=str, default='cnn', choices=['cnn', 'fc'],
                        help='Model architecture type (auto-detected from checkpoint if available)')
    parser.add_argument('--save-viz', type=str,
                        help='Path to save visualization')
    parser.add_argument('--use-gpu', action='store_true',
                        help='Use GPU if available')

    args = parser.parse_args()

    # Only query CUDA when the user actually asked for the GPU.
    use_cuda = args.use_gpu and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f"Using device: {device}")

    model = load_model(args.model_path, args.model_type, device)

    if args.image_path:
        print(f"\nProcessing single image: {args.image_path}")

        img_tensor, img_array = preprocess_image(args.image_path)
        predicted, confidence, probabilities = predict(model, img_tensor, device)

        print(f"\n{'='*50}")
        print(f"Prediction: {predicted}")
        print(f"Confidence: {confidence*100:.2f}%")
        print(f"{'='*50}")

        print("\nAll class probabilities:")
        for digit, probability in enumerate(probabilities):
            print(f" {digit}: {probability*100:.2f}%")

        # A visualization file is always written; --save-viz only renames it.
        save_path = args.save_viz or 'prediction_visualization.png'
        visualize_prediction(img_array, predicted, confidence, probabilities, save_path)

    elif args.image_dir:
        print(f"\nProcessing directory: {args.image_dir}")

        image_paths = _collect_image_paths(args.image_dir)

        if not image_paths:
            print("No images found in directory!")
            return

        print(f"Found {len(image_paths)} images")

        results = predict_batch(model, [str(p) for p in image_paths], device)

        print(f"\n{'='*50}")
        print("Summary:")
        print(f"{'='*50}")
        for result in results:
            print(f"{Path(result['image_path']).name}: {result['predicted']} ({result['confidence']*100:.1f}%)")

    else:
        print("Please provide either --image-path or --image-dir")
        return
|
|
# Run the command-line interface only when executed as a script, so the
# module can be imported (e.g. to reuse the model classes) without side effects.
if __name__ == "__main__":
    main()