| import torch |
| import torch.nn as nn |
| import os |
| import time |
| from tools import mutils |
|
|
# Scratch state for save_grid(): the gradient (and its label) from the first
# call is parked here so a later call can be compared against it.
saved_grad, saved_name = None, None

# Root folder for every image written by the helpers below. It is created as a
# side effect of importing this module, so the writers can assume it exists.
base_url = './results'
os.makedirs(name=base_url, exist_ok=True)
|
|
|
|
def normalize_tensor_mm(tensor):
    """Linearly rescale ``tensor`` so its minimum maps to 0 and its maximum to 1.

    A constant tensor (max == min) has no range to stretch. Instead of dividing
    by zero and returning NaNs, it is mapped to all zeros.

    Args:
        tensor: any non-empty tensor (``min()`` raises on an empty one).

    Returns:
        A tensor of the same shape. True division promotes integer input to
        floating point, as before.
    """
    shifted = tensor - tensor.min()
    # max(t - min) == max(t) - min, so this is the value range.
    value_range = shifted.max()
    if value_range == 0:
        # Everything in `shifted` is already 0; divide by 1 to keep the same
        # dtype promotion as the general path instead of producing 0/0 = NaN.
        value_range = torch.ones_like(value_range)
    return shifted / value_range
|
|
|
|
def normalize_tensor_sigmoid(tensor):
    """Squash ``tensor`` element-wise into [0, 1] with the logistic sigmoid.

    Args:
        tensor: floating-point tensor of any shape.

    Returns:
        ``1 / (1 + exp(-tensor))`` with the same shape as the input.
    """
    squashed = torch.sigmoid(tensor)
    return squashed
|
|
|
|
def save_image(tensor, name=None, save_path=None, exit_flag=False, timestamp=False, nrow=4, split_dir=None):
    """Write ``tensor`` to disk as a single PNG grid using torchvision.

    Args:
        tensor: batch handed to ``torchvision.utils.make_grid`` (detached and
            moved to CPU first). Presumably (B, C, H, W) with values in [0, 1];
            confirm against callers.
        name: file stem, used only when ``save_path`` is not given.
            NOTE(review): if neither is given the file is written as ``None.png``.
        save_path: explicit output path. When truthy it takes precedence over
            ``name``/``timestamp``/``split_dir`` for the file location.
        exit_flag: terminate the process with status 0 after saving.
        timestamp: append ``mutils.get_timestamp()`` to ``name``.
        nrow: number of images per grid row.
        split_dir: optional sub-folder of ``base_url``. It is created if missing,
            even when ``save_path`` is used.
    """
    import sys
    import torchvision.utils as vutils

    _base_url = os.path.join(base_url, split_dir) if split_dir else base_url
    os.makedirs(_base_url, exist_ok=True)
    grid = vutils.make_grid(tensor.detach().cpu(), nrow=nrow)

    if save_path:
        target = save_path
    elif timestamp:
        target = f'{_base_url}/{name}_{mutils.get_timestamp()}.png'
    else:
        target = f'{_base_url}/{name}.png'
    vutils.save_image(grid, target)

    if exit_flag:
        # sys.exit instead of the site-module `exit()`, which is meant for the
        # interactive shell (absent under `python -S`, and it closes stdin).
        sys.exit(0)
|
|
|
|
def save_feature(tensor, name, exit_flag=False, timestamp=False, nrow=8):
    """Save each channel of a feature map as a grayscale tile in one grid image.

    The image is written to ``{base_url}/{name}_original.png``, and that path
    is printed.

    Args:
        tensor: feature map. A leading dimension of size 1 is squeezed away, and
            every entry along the next dimension becomes one tile. Presumably
            (1, C, H, W) or (C, H, W); confirm against callers.
        name: file stem.
        exit_flag: terminate the process with status 0 afterwards.
        timestamp: append ``time.time()`` (with the decimal point removed) to
            ``name`` so repeated calls do not overwrite each other.
        nrow: tiles per grid row (default 8).
    """
    import sys
    import torchvision.utils as vutils

    if timestamp:
        name += '_' + str(time.time()).replace('.', '')

    # (suffix, tensor) pairs to dump. Only the raw map is written today. Rescaled
    # variants (e.g. via normalize_tensor_mm / normalize_tensor_sigmoid) could
    # be added here with their own suffix.
    variants = [('original', tensor)]
    for suffix, feature in variants:
        # squeeze(0) drops a size-1 leading dim. unsqueeze(1) turns each channel
        # into its own single-channel image, which make_grid then tiles.
        planes = feature.detach().cpu().squeeze(0).unsqueeze(1)
        grid = vutils.make_grid(planes, nrow=nrow)
        out_path = f'{base_url}/{name}_{suffix}.png'
        vutils.save_image(grid, out_path)
        print(out_path)

    if exit_flag:
        # sys.exit rather than the interactive-only site `exit()`.
        sys.exit(0)
|
|
|
|
def save(tensor, name, exit_flag=False):
    """Save every channel of ``tensor`` as a grayscale tile, 4 tiles per row.

    The grid is written to ``{base_url}/{name}.png``.

    Args:
        tensor: feature map. A size-1 leading dimension is squeezed and each
            remaining entry along the next dimension becomes one tile.
            Presumably (1, C, H, W) with values in [0, 1]; confirm against callers.
        name: file stem.
        exit_flag: terminate the process with status 0 after saving.
    """
    import sys
    import torchvision.utils as vutils

    planes = tensor.detach().cpu().squeeze(0).unsqueeze(1)
    grid = vutils.make_grid(planes, nrow=4)
    vutils.save_image(grid, f'{base_url}/{name}.png')

    if exit_flag:
        # sys.exit rather than the interactive-only site `exit()`.
        sys.exit(0)
|
|
|
|
def save_grid_direct(grad, name, channels=8, size=320):
    """Dump a gradient as an image, print value statistics and plot a histogram.

    ``grad`` is reshaped to (1, channels, size, size) and scaled by
    ``255 / (size * size)``. The result, clamped to [0, 255], is saved with
    ``save``. The same values clamped to [-200, 200] are then summarised and
    plotted as a 50-bin density histogram in a blocking matplotlib window.

    Args:
        grad: tensor with exactly ``channels * size * size`` elements.
        name: file stem passed to ``save`` and label for the printed stats.
        channels: number of channels in the reshaped view (default 8).
        size: spatial height/width of the reshaped view (default 320).
    """
    import matplotlib.pyplot as plt
    import numpy as np

    grad = grad.view(1, channels, size, size) * 255 / (size * size)
    # NOTE(review): save_image expects [0, 1], so anything above 1 saturates.
    # save_grid divides by 255 after the same clamp; confirm which is intended.
    save(grad.clamp(0, 255), name)

    module_grad = grad.clamp(-200, 200)
    print(module_grad.min().item(), module_grad.max().item(), module_grad.mean().item())

    module_grad_flat = module_grad.flatten()
    num_total = module_grad_flat.numel()
    # Count with boolean reductions instead of materialising masked copies.
    num_negative = int((module_grad_flat < 0).sum().item())
    num_zero = int((module_grad_flat == 0).sum().item())
    print(name, num_negative / num_total, num_negative, num_zero)

    # detach() so .numpy() also works when grad is part of an autograd graph.
    values = module_grad_flat.detach().cpu().numpy()
    y, x = np.histogram(values, bins=50, density=True)
    # x holds 51 bin edges; the left edges line up with the 50 heights in y.
    plt.bar(x[:-1], y)
    plt.show()
|
|
|
|
def save_grid(grad, name, exit_flag=False):
    """Debug helper: compare a gradient with the one captured on an earlier call.

    While the module-level ``saved_grad`` is None, the call only caches ``grad``
    (as a detached copy) and ``name``. The next call computes the element-wise
    ratio ``grad / (saved_grad + 1e-7)``, saves it as an image via ``save``,
    prints summary statistics and shows a 50-bin density histogram. It then
    terminates the process unconditionally, so ``exit_flag`` only matters on
    the caching call.

    Args:
        grad: gradient tensor; must be broadcast-compatible with the cached one.
        name: label printed on the first call, and the file stem for ``save``
            on the second.
        exit_flag: terminate the process with status 0 after caching.
    """
    global saved_grad, saved_name
    import sys

    print(grad.shape)
    if saved_grad is None:
        print(name)
        # Snapshot instead of keeping a reference. Gradient buffers such as
        # `param.grad` are accumulated in place by later backward passes,
        # which would silently change the cached baseline.
        saved_grad = grad.detach().clone()
        saved_name = name  # kept for inspection; not read in this function
    else:
        # Small epsilon avoids dividing by exact zeros in the cached gradient.
        module_grad = grad / (saved_grad + 1e-7)
        print(module_grad.max())
        save(module_grad.clamp(0, 255) / 255., name)

        module_grad = module_grad.clamp(-300, 300)
        print(module_grad.min().item(), module_grad.max().item(), module_grad.mean().item())

        module_grad_flat = module_grad.flatten()
        num_total = module_grad_flat.numel()
        num_negative = int((module_grad_flat < 0).sum().item())
        num_zero = int((module_grad_flat == 0).sum().item())
        print(name, num_negative / num_total, num_negative, num_zero)

        import matplotlib.pyplot as plt
        import numpy as np

        # detach() so .numpy() also works when grad is part of an autograd graph.
        values = module_grad_flat.detach().cpu().numpy()
        y, x = np.histogram(values, bins=50, density=True)
        plt.bar(x[:-1], y)
        plt.show()
        # The comparison is a one-shot diagnostic: stop the run after it.
        sys.exit(0)

    if exit_flag:
        sys.exit(0)
|
|
|
|
def show_grid(grid, name, exit_flag=False):
    """Min-max normalise ``grid``, tile its channels 4 per row and display it.

    Args:
        grid: feature map. A size-1 leading dimension is squeezed and each entry
            along the next dimension becomes one grayscale tile. Presumably
            (1, C, H, W); confirm against callers.
        name: title of the matplotlib figure.
        exit_flag: terminate the process with status 0 after the window closes.
    """
    import sys
    import torchvision.utils as vutils
    import torchvision.transforms as vtrans
    import matplotlib.pyplot as plt

    # Detach like save()/save_image() so no autograd ops are recorded.
    grid = grid.detach()
    grid = grid - grid.min()
    value_range = grid.max()
    if value_range == 0:
        # A constant map would otherwise become 0/0 = NaN everywhere.
        value_range = torch.ones_like(value_range)
    grid = grid / value_range

    grid = vutils.make_grid(grid.cpu().squeeze(0).unsqueeze(1), nrow=4)
    plt.imshow(vtrans.ToPILImage()(grid))
    plt.title(name)
    plt.show()

    if exit_flag:
        # sys.exit rather than the interactive-only site `exit()`.
        sys.exit(0)
|
|
|
|
def show_img(img, name, exit_flag=False):
    """Display ``img`` in a matplotlib window.

    Args:
        img: image tensor. A size-1 leading dimension is squeezed before
            ``make_grid``. Presumably (1, C, H, W) with values in [0, 1];
            confirm against callers.
        name: title of the matplotlib figure.
        exit_flag: terminate the process with status 0 after the window closes.
    """
    import sys
    import torchvision.utils as vutils
    import torchvision.transforms as vtrans
    import matplotlib.pyplot as plt

    # Detach like save()/save_image() so no autograd ops are recorded.
    grid = vutils.make_grid(img.detach().cpu().squeeze(0))
    plt.imshow(vtrans.ToPILImage()(grid))
    plt.title(name)
    plt.show()

    if exit_flag:
        # sys.exit rather than the interactive-only site `exit()`.
        sys.exit(0)
|
|
|
|
class SaverBlock(nn.Module):
    """Pass-through layer that dumps ``x[0]`` with save_feature() on every forward.

    It has no parameters. Insert it into an ``nn.Sequential`` to inspect
    intermediate activations.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Presumably x is a batch and x[0] its first sample; confirm at call site.
        first_item = x[0]
        save_feature(first_item, 'intermediate_', timestamp=True)
        return x
|
|