| import argparse
|
| import torch
|
| import matplotlib.pyplot as plt
|
| import requests
|
| from PIL import Image
|
| from torchvision import transforms
|
| import data_setup, model_builder
|
| from pathlib import Path
|
| import os
|
|
|
# Command-line interface: the script classifies a single image fetched from
# the URL given via -i/--image.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--image",
    help="string of url to the image",
    type=str,
    # Without a URL there is nothing to classify; let argparse report the
    # missing flag instead of failing later with "Invalid URL 'None'".
    required=True,
)
args = parser.parse_args()

# URL of the image to download and classify.
URL = args.image
|
|
|
# Preprocessing applied to the downloaded PIL image before inference:
#   1. resize to 224x224,
#   2. convert to a tensor,
#   3. normalize each channel with the mean/std below.
# NOTE(review): the mean/std values equal the standard ImageNet channel
# statistics, presumably to match the pretrained backbone -- confirm against
# the training transform.
image_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
# Root of the dataset; each immediate sub-directory is treated as one class.
IMAGE_PATH = Path("data") / "spoiled-fresh" / "FRUIT-16K"

# Class labels are the sub-directory names in sorted order.
# NOTE(review): this presumably mirrors the label indexing used at training
# time (alphabetical folder order) -- confirm against data_setup.
classes = sorted(entry.name for entry in IMAGE_PATH.iterdir() if entry.is_dir())

# Build the architecture on CPU with one output per class, then restore the
# trained weights.
loaded_model = model_builder.create_model_baseline_effnetb2(
    out_feats=len(classes),
    device="cpu",
)
loaded_model.load_state_dict(
    torch.load(
        "models/effnetb2_fruitsvegs0_5_epochs.pt",
        # The model lives on CPU; remap tensors so a checkpoint saved from a
        # GPU run still loads on a machine without CUDA.
        map_location="cpu",
        weights_only=True,
    )
)
# Modules start in training mode; switch to eval so dropout / batch-norm
# layers behave deterministically during inference.
loaded_model.eval()
|
|
|
def pred_and_plot(model: torch.nn.Module,
                  image_path: str,
                  transform: transforms.Compose,
                  class_names: list[str] = None):
    """Download an image, classify it with ``model`` and plot/print the result.

    The image is drawn on the current matplotlib axes with a title of the
    form ``"<label> | <probability>"``; the same string is printed. Showing
    or saving the figure is left to the caller.

    Args:
        model: Classifier called on a single-image batch and expected to
            return class logits for it.
        image_path: URL of the image to fetch (despite the name it is passed
            to ``requests.get``, not opened from disk).
        transform: Preprocessing applied to the RGB PIL image; its output
            gets a leading batch dimension before being fed to ``model``.
        class_names: Labels indexed by predicted class id. When ``None``,
            the numeric class index is used as the label.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server responds with an HTTP error status.
    """
    # Bound the wait so a stalled server cannot hang the script, and release
    # the streamed connection once the image has been decoded.
    with requests.get(image_path, stream=True, timeout=30) as response:
        response.raise_for_status()
        # convert() loads the full image, so the pixel data is in memory
        # before the connection is closed.
        img = Image.open(response.raw).convert("RGB")

    batch = transform(img).unsqueeze(dim=0)

    # Run in eval mode without autograd: with a batch of one, training-mode
    # batch-norm/dropout would distort the prediction. The caller's
    # train/eval state is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            logits = model(batch)
    finally:
        model.train(was_training)

    # Compute the probabilities once and read both the winning class and its
    # confidence from them.
    probs = torch.softmax(logits, dim=-1)
    confidence, index = probs.max(dim=-1)
    pred = int(index.item())

    label = class_names[pred] if class_names is not None else pred
    title = f"{label} | {confidence.item():.3f}"

    plt.imshow(img)
    plt.axis("off")
    plt.title(title)
    print(title)
|
|
|
# Script entry: classify the image at the URL supplied on the command line
# using the restored model, the preprocessing pipeline and the class labels.
pred_and_plot(
    model=loaded_model,
    image_path=URL,
    transform=image_transform,
    class_names=classes,
)