import os
from collections.abc import Sequence
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import make_grid, save_image

try:
    from kornia.morphology import opening
except ImportError:
    from kornia.morphology import open as opening
|
|
def exist(val: Any) -> bool:
    """Return ``True`` unless *val* is ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing;
    only the ``None`` singleton is treated as absent.
    """
    if val is None:
        return False
    return True
|
|
def morph_open(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply a morphological opening with a ``k x k`` all-ones structuring element.

    A kernel size of ``0`` disables the operation: *x* itself (the same
    object) is returned untouched. Otherwise the opening is computed with
    kornia's ``opening`` under ``torch.no_grad()``, so no autograd graph is
    recorded for it.

    Args:
        x: Input tensor handed to kornia's ``opening``.
        k: Side length of the square structuring element; ``0`` means "skip".

    Returns:
        The opened tensor, or *x* unchanged when ``k == 0``.
    """
    if k == 0:
        return x
    with torch.no_grad():
        # Structuring element lives on the same device as the input.
        structuring_element = torch.ones(k, k, device=x.device)
        return opening(x, structuring_element)
|
|
def make_grid_images(images: list[torch.Tensor], **kwargs) -> torch.Tensor:
    """Place aligned batches side by side and tile the result into one grid.

    Args:
        images: Batches that are concatenated along dim 3, so they must agree
            in every other dimension.
        **kwargs: Forwarded unchanged to ``make_grid`` (e.g. ``nrow``,
            ``padding``).

    Returns:
        The grid tensor produced by ``make_grid``.
    """
    return make_grid(torch.cat(images, dim=3), **kwargs)
|
|
def save_images(images: tuple[torch.Tensor, torch.Tensor], path: str, **kwargs) -> None:
    """Save generated and real batches side by side as a single image grid.

    Each generated image is placed to the left of its real counterpart (the
    two batches are concatenated along dim 3), and the resulting pairs are
    tiled with ``make_grid``.

    Args:
        images: ``(gen, real)`` pair of batches; they must agree in every
            dimension except dim 3. Values are expected in ``[0, 1]``;
            anything outside that range is clamped when written.
        path: Destination file path passed to ``save_image``.
        **kwargs: Forwarded unchanged to ``make_grid`` (e.g. ``nrow``,
            ``padding``).
    """
    gen, real = images
    # No autograd graph is needed just to write pixels; this also lets callers
    # pass tensors that still require grad.
    with torch.no_grad():
        grid = make_grid(torch.cat((gen, real), dim=3), **kwargs)
    # save_image does the [0, 1] -> uint8 conversion itself (scale, round,
    # clamp). A manual `astype(np.uint8)` would truncate instead of rounding
    # and wrap out-of-range values around (e.g. 1.01 -> 257 -> 1).
    save_image(grid, path)
|
|
def save_triplet(images: tuple[torch.Tensor, ...], path: str, **kwargs) -> None:
    """Save any number of aligned batches side by side as a single image grid.

    The batches are concatenated along dim 3, so for each index the images
    from ``images[0]``, ``images[1]``, ... appear left to right. The resulting
    rows are tiled with ``make_grid``.

    Args:
        images: Batches that must agree in every dimension except dim 3.
            Values are expected in ``[0, 1]``; anything outside that range is
            clamped when written.
        path: Destination file path passed to ``save_image``.
        **kwargs: Forwarded unchanged to ``make_grid`` (e.g. ``nrow``,
            ``padding``).
    """
    # No autograd graph is needed just to write pixels; this also lets callers
    # pass tensors that still require grad.
    with torch.no_grad():
        grid = make_grid(torch.cat(images, dim=3), **kwargs)
    # save_image does the [0, 1] -> uint8 conversion itself (scale, round,
    # clamp). A manual `astype(np.uint8)` would truncate instead of rounding
    # and wrap out-of-range values around (e.g. 1.01 -> 257 -> 1).
    save_image(grid, path)
|
|
def plot_images(images: torch.Tensor) -> None:
    """Display a batch of images as one horizontal strip in a matplotlib figure.

    The images along the first dimension are concatenated along the last
    (width) dimension, and the channel axis is moved last for ``imshow``.

    Args:
        images: Batch tensor, presumably ``(B, C, H, W)`` with values suitable
            for ``imshow`` (e.g. floats in ``[0, 1]``) — TODO confirm with
            callers. Single-channel batches are shown in grayscale.
    """
    # detach() so tensors that still require grad can be converted to NumPy.
    batch = images.detach().cpu()
    # (B, C, H, W) -> (C, H, B*W) -> (H, B*W, C)
    strip = torch.cat(list(batch), dim=-1).permute(1, 2, 0).numpy()
    plt.figure(figsize=(32, 32))
    if strip.shape[-1] == 1:
        # imshow rejects (H, W, 1) arrays; drop the channel axis instead.
        plt.imshow(strip[..., 0], cmap="gray")
    else:
        plt.imshow(strip)
    plt.show()
|
|
def make_graphic(metric_name: str, metrics: list[torch.Tensor], path: str) -> None:
    """Plot a per-epoch metric history and save it as ``<path>/<metric_name>.png``.

    Args:
        metric_name: Used as the plot title, the y-axis label and the file
            stem.
        metrics: One value per epoch, plotted against its list index. Tensors
            (on any device, with or without grad) and plain Python numbers
            are both accepted.
        path: Directory the PNG file is written into.
    """
    # as_tensor accepts plain numbers too; detach() allows tensors that still
    # require grad to be converted to NumPy.
    values = [torch.as_tensor(m).detach().cpu().numpy() for m in metrics]
    fig = plt.figure(figsize=(32, 32))
    try:
        plt.plot(values)
        plt.title(metric_name)
        plt.xlabel("Epoch")
        plt.ylabel(metric_name)
        fig.savefig(os.path.join(path, f"{metric_name}.png"))
    finally:
        # Close even when saving fails, so figures do not pile up across calls.
        plt.close(fig)
|
|
def norm(
    img: torch.Tensor,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Normalize *img* channel-wise as ``(img - mean) / std``.

    With the default statistics this maps values in ``[0, 1]`` to
    ``[-1, 1]``.

    Args:
        img: Image tensor, passed to ``torchvision.transforms.Normalize``
            (see its documentation for the accepted shapes and dtypes).
        mean: Per-channel means.
        std: Per-channel standard deviations.

    Returns:
        The normalized tensor produced by ``Normalize``.
    """
    # Defaults are tuples: immutable, so no list object is shared between calls.
    return transforms.Normalize(mean, std)(img)
|
|
def denorm(
    img: torch.Tensor,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Undo channel-wise normalization: return ``img * std + mean``.

    This is the inverse of ``norm`` for the same *mean* and *std*.

    Args:
        img: Normalized image tensor whose channel axis is third from last,
            e.g. ``(C, H, W)`` or ``(B, C, H, W)``.
        mean: Per-channel means used for normalization.
        std: Per-channel standard deviations used for normalization.

    Returns:
        Tensor with the same shape as *img*, and the same dtype when *img*
        is floating point.
    """
    # Keep the image's floating dtype (as torchvision's normalize does) so that
    # e.g. float16 input is not silently promoted to float32.
    dtype = img.dtype if img.is_floating_point() else None
    mean_t = torch.as_tensor(mean, dtype=dtype, device=img.device)
    std_t = torch.as_tensor(std, dtype=dtype, device=img.device)
    # (C, 1, 1) broadcasts over (..., C, H, W) without prepending a batch dim,
    # so a single (C, H, W) image stays 3-D.
    if mean_t.ndim == 1:
        mean_t = mean_t.view(-1, 1, 1)
    if std_t.ndim == 1:
        std_t = std_t.view(-1, 1, 1)
    return img * std_t + mean_t