| """
|
| Contains utility functions for PyTorch model training: saving models, plotting predictions, seeding RNGs, and creating TensorBoard writers.
|
| """
|
| import torch
|
| from pathlib import Path
|
| import matplotlib.pyplot as plt
|
| import torchvision
|
| from PIL import Image
|
| from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
    """Save a PyTorch model's ``state_dict`` to a target directory.

    The target directory (and any missing parents) is created if needed.

    Args:
        model: PyTorch model whose ``state_dict()`` will be saved.
        target_dir: Path of the directory in which to store the saved model.
        model_name: Filename for the saved model. Must end with ".pth" or ".pt".

    Raises:
        ValueError: If ``model_name`` does not end with ".pth" or ".pt".
    """
    # Validate before touching the filesystem so an invalid name does not leave
    # an empty directory behind. An explicit raise (instead of ``assert``) keeps
    # the check active when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model name should end with .pt or .pth")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)
    model_save_path = target_dir_path / model_name

    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
|
|
|
def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: "list[str] | None" = None,
    transform=None,
    device: "torch.device | str" = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Make a prediction on a target image with a trained model and plot the image.

    The image is drawn with matplotlib and titled with the predicted class and
    its softmax probability.

    Args:
        model (torch.nn.Module): Trained PyTorch image classification model.
            It is moved to ``device`` and switched to eval mode in place.
        image_path (str): Filepath to the target image.
        class_names (list[str], optional): Class names indexed by the model's
            output index. If None (or empty), the raw predicted index is shown
            in the title instead. Defaults to None.
        transform (callable, optional): Applied to the opened PIL image; it
            must return a (C, H, W) tensor. Defaults to None, in which case
            ``torchvision.transforms.ToTensor()`` is used.
        device (torch.device or str, optional): Device to run inference on.
            Defaults to "cuda" if torch.cuda.is_available() else "cpu".

    Returns:
        None. The image and prediction title are drawn on the current
        matplotlib figure.

    Example usage:
        pred_and_plot_image(model=model,
                            image_path="some_image.jpeg",
                            class_names=["class_1", "class_2", "class_3"],
                            transform=torchvision.transforms.ToTensor(),
                            device=device)
    """
    # Without a transform there would be no tensor to feed the model; fall back
    # to the plain PIL -> tensor conversion.
    if transform is None:
        transform = torchvision.transforms.ToTensor()

    # Open the image in a context manager so its file handle gets closed. The
    # transform runs inside the block because PIL loads pixel data lazily.
    with Image.open(image_path) as img:
        target_image = transform(img)

    model.to(device)
    model.eval()
    with torch.inference_mode():
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        batched_image = target_image.unsqueeze(dim=0)
        target_image_pred = model(batched_image.to(device))
        target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
        target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # Convert to plain Python numbers so the title shows values, not tensor reprs.
    pred_idx = target_image_pred_label.item()
    pred_prob = target_image_pred_probs.max().item()

    # (C, H, W) -> (H, W, C) for matplotlib. Drop a singleton channel axis
    # explicitly; a blanket squeeze() would also remove the channel of a
    # 1-channel image before the permute and make it fail.
    image_to_plot = target_image.permute(1, 2, 0)
    if image_to_plot.shape[-1] == 1:
        image_to_plot = image_to_plot.squeeze(-1)
    plt.imshow(image_to_plot)

    if class_names:
        title = f"Pred: {class_names[pred_idx]} | Prob: {pred_prob:.3f}"
    else:
        title = f"Pred: {pred_idx} | Prob: {pred_prob:.3f}"
    plt.title(title)
    plt.axis(False)
|
|
|
def set_seeds(seed: int = 42):
    """Seed torch's random number generators so runs are reproducible.

    Args:
        seed (int, optional): Value used to seed the generators. Defaults to 42.
    """
    # Seed the default torch generator first, then the current CUDA device's one.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
|
|
|
|
|
# The return annotation names the class itself. Writing ``SummaryWriter()``
# there would call the constructor at definition (import) time, creating a
# writer and its log directory as a side effect.
def create_writer(experiment_name: str,
                  model_name: str,
                  extra: "str | None" = None) -> SummaryWriter:
    """Create a SummaryWriter instance saving to a specific log_dir.

    log_dir is a combination of runs/timestamp/experiment_name/model_name/extra,
    where timestamp is the current local date in YYYY-MM-DD format. ``extra``
    is appended only when it is truthy.

    Args:
        experiment_name (str): Name of the experiment.
        model_name (str): Name of the model.
        extra (str, optional): Anything extra to add to the directory.
            Defaults to None.

    Returns:
        SummaryWriter: Instance of a writer saving to log_dir.

    Example usage:
        # Creates a writer saving to "runs/2022-06-04/data_10_percent/effnetb2/5_epochs"
        writer = create_writer(experiment_name="data_10_percent",
                               model_name="effnetb2",
                               extra="5_epochs")

        # This is the same as:
        writer = SummaryWriter(log_dir="runs/2022-06-04/data_10_percent/effnetb2/5_epochs")
    """
    from datetime import datetime
    import os

    timestamp = datetime.now().strftime("%Y-%m-%d")

    # Build the path components once; "" and None both mean "no extra level".
    path_parts = ["runs", timestamp, experiment_name, model_name]
    if extra:
        path_parts.append(extra)
    log_dir = os.path.join(*path_parts)

    print(f"[INFO] Created SummaryWriter(), saving to: {log_dir}")
    return SummaryWriter(log_dir=log_dir)
|
|
|