import math
import os
from pathlib import Path

import cv2
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SwinForImageClassification, AutoImageProcessor
|
|
|
class CoinPredictor:
    """Classify coin images with a fine-tuned Swin image-classification model.

    The predictor loads a ``SwinForImageClassification`` checkpoint and a
    joblib-serialized label encoder from ``model_dir``, crops a border off
    each query image, and returns the highest-probability labels.
    """

    def __init__(self, model_dir='model_checkpoints', top_n=10, crop_percentage=0.15):
        """
        Initialize the predictor with trained model and necessary components.

        Args:
            model_dir (str): Directory containing the saved model
                (``best_model``) and label encoder (``label_encoder.joblib``)
            top_n (int): Number of top predictions to return
            crop_percentage (float): Fraction of the height/width removed from
                EACH side of the image before classification; must satisfy
                ``0 <= crop_percentage < 0.5`` so the crop is never empty

        Raises:
            ValueError: If ``crop_percentage`` is outside ``[0, 0.5)``.
        """
        if not 0 <= crop_percentage < 0.5:
            raise ValueError(
                f"crop_percentage must be in [0, 0.5), got {crop_percentage}"
            )

        self.top_n = top_n
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.crop_percentage = crop_percentage

        # Preprocessing settings come from the base checkpoint on the Hub,
        # not from model_dir.
        self.image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

        model_path = os.path.join(model_dir, 'best_model')
        self.model = SwinForImageClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        encoder_path = os.path.join(model_dir, 'label_encoder.joblib')
        # NOTE(security): joblib.load unpickles arbitrary objects; only point
        # model_dir at directories you trust.
        self.label_encoder = joblib.load(encoder_path)

        print(f"Model loaded and running on {self.device}")

    @staticmethod
    def _grid_rows(n_panels, n_cols):
        """Return the number of rows needed to fit n_panels in n_cols columns (at least 1)."""
        return max(1, math.ceil(n_panels / n_cols))

    @staticmethod
    def _find_reference(reference_dir, label):
        """Return the path of the first existing ``<label>.jpg/.jpeg/.png`` in reference_dir, else None."""
        for ext in ('.jpg', '.jpeg', '.png'):
            # str() lets non-string labels (e.g. integer classes) be looked up.
            candidate = os.path.join(reference_dir, str(label) + ext)
            if os.path.exists(candidate):
                return candidate
        return None

    def crop_center(self, image):
        """
        Crop the center portion of the image.

        Removes ``crop_percentage`` of the height from the top and bottom and
        ``crop_percentage`` of the width from the left and right.

        Args:
            image: Array whose first two dimensions are (height, width)

        Returns:
            The cropped slice of ``image``
        """
        h, w = image.shape[:2]
        crop_h = int(h * self.crop_percentage)
        crop_w = int(w * self.crop_percentage)

        return image[crop_h:h - crop_h, crop_w:w - crop_w]

    def preprocess_image(self, image_path):
        """
        Preprocess a single image for prediction.

        Reads the file with OpenCV, center-crops it, converts BGR to RGB and
        wraps the result in a PIL image.

        Args:
            image_path (str or os.PathLike): Path to the image file

        Returns:
            PIL.Image.Image: The cropped RGB image

        Raises:
            ValueError: If the image cannot be read.
        """
        # os.fspath keeps str paths unchanged and lets pathlib.Path work too.
        image = cv2.imread(os.fspath(image_path))
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        image = self.crop_center(image)

        # OpenCV loads BGR; PIL and the image processor expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return Image.fromarray(image)

    def predict(self, image_path):
        """
        Make prediction for a single image.

        Args:
            image_path (str): Path to the image file

        Returns:
            tuple: ``(predictions, image)`` where ``predictions`` is a list of
            ``(label, probability)`` tuples ordered from most to least likely
            (at most ``top_n`` entries) and ``image`` is the preprocessed PIL
            image that was fed to the model

        Raises:
            ValueError: If the image cannot be read.
        """
        image = self.preprocess_image(image_path)

        inputs = self.image_processor(image, return_tensors="pt")
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

        # torch.topk raises if k exceeds the number of classes, so clamp it.
        k = min(self.top_n, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        predictions = [
            (self.label_encoder.inverse_transform([idx])[0], float(prob))
            for prob, idx in zip(top_probs.cpu().numpy(), top_indices.cpu().numpy())
        ]

        return predictions, image

    def visualize_prediction(self, image_path, predictions, reference_dir="all_coins_cropped"):
        """
        Visualize the input image and top N matching reference images with probabilities.

        The first two panels show the original and the preprocessed query;
        each following panel shows the reference image for one prediction.
        Predictions without a readable reference image leave their panel empty.

        Args:
            image_path (str): Path to the query image
            predictions (tuple): (predictions, preprocessed_image) as returned
                by ``predict``
            reference_dir (str): Directory containing reference images named
                ``<label>.jpg``, ``<label>.jpeg`` or ``<label>.png``

        Raises:
            ValueError: If the query image cannot be read.
        """
        ranked, processed_image = predictions

        n_cols = 4
        # Two query panels plus one panel per prediction actually supplied.
        n_rows = self._grid_rows(len(ranked) + 2, n_cols)

        plt.figure(figsize=(15, 4 * n_rows))

        original_img = cv2.imread(os.fspath(image_path))
        if original_img is None:
            raise ValueError(f"Could not load image: {image_path}")
        plt.subplot(n_rows, n_cols, 1)
        plt.imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
        plt.title("Original Query")
        plt.axis('off')

        plt.subplot(n_rows, n_cols, 2)
        plt.imshow(processed_image)
        plt.title("Processed Query")
        plt.axis('off')

        # Panels 1 and 2 are taken by the query images, so predictions start at 3.
        for slot, (label, prob) in enumerate(ranked, 3):
            ref_path = self._find_reference(reference_dir, label)
            if ref_path is None:
                continue

            ref_img = cv2.imread(ref_path)
            if ref_img is None:
                # File exists but is not a readable image; skip it like a missing one.
                continue

            plt.subplot(n_rows, n_cols, slot)
            plt.imshow(cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB))
            plt.title(f"{label}\n{prob:.1%}")
            plt.axis('off')

        plt.tight_layout()
        plt.show()
|
|
|
def main():
    """Prompt for a coin image path, print the ranked predictions and plot them.

    Errors raised while predicting or visualizing are reported on stdout as
    ``Error: <message>`` instead of propagating.
    """
    predictor = CoinPredictor()

    image_path = input("Enter the path to the coin image: ").strip()
    # Paths pasted or dragged into a terminal often arrive wrapped in quotes.
    if len(image_path) >= 2 and image_path[0] == image_path[-1] and image_path[0] in "'\"":
        image_path = image_path[1:-1]

    try:
        # predict returns (ranked predictions, preprocessed image).
        result = predictor.predict(image_path)

        print("\nPredictions:")
        for rank, (label, prob) in enumerate(result[0], 1):
            print(f"{rank}. {label}: {prob:.1%}")

        predictor.visualize_prediction(image_path, result)

    except Exception as e:
        # Top-level CLI boundary: report the failure without a traceback.
        print(f"Error: {e}")
|
|
|
# Run the interactive CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()