|
|
| try: |
| import torch |
| import torch.nn as nn |
| except: |
| pass |
|
|
| try: |
| import torchvision as tv |
| import torchvision.transforms as T |
| import torchvision.transforms.functional as F |
| except: |
| pass |
|
|
| try: |
| import pytorch_lightning as pl |
| except: |
| pass |
# Optional metric dependencies.
# BUGFIX: torchmetrics and lpips were imported in a single try-block, so a
# missing lpips clobbered a perfectly good torchmetrics with the stub below.
# Import them independently.
try:
    import torchmetrics
except Exception:
    from argparse import Namespace
    # Minimal stand-in so the Metric subclasses below can still be defined.
    torchmetrics = Namespace(Metric=object)
try:
    import lpips
except Exception:
    pass
| try: |
| import wandb |
| except: |
| pass |
|
|
| try: |
| import kornia |
| except: |
| pass |
|
|
| try: |
| import detectron2 |
| from detectron2 import model_zoo as _ |
| from detectron2 import engine as _ |
| from detectron2 import config as _ |
| from detectron2 import data as _ |
| from detectron2.utils import visualizer as _ |
| except: |
| pass |
|
|
| try: |
| from nvidia import dali |
| from nvidia.dali.plugin import pytorch as _ |
| except: |
| pass |
|
|
| try: |
| import cupy |
| except: |
| pass |
|
|
| try: |
| import skimage |
| from skimage import measure as _ |
| from skimage import color as _ |
| from skimage import segmentation as _ |
| from skimage import filters as _ |
| from scipy.spatial.transform import Rotation |
| except: |
| pass |
|
|
| |
| import math |
| from twodee_v0 import * |
| from PIL import Image |
| |
|
|
| |
|
|
def cupy_launch(func, kernel):
    """Compile CUDA source *kernel* with CuPy and return the kernel named *func*.

    Returns None when CuPy is not available.

    NOTE(review): the original wrapped only the `def` statement in try/except;
    a `def` cannot raise, so the fallback lambda was dead code and a missing
    cupy surfaced as NameError at call time. The availability check now
    happens when the function is called.
    """
    try:
        return cupy.cuda.compile_with_cache(kernel).get_function(func)
    except NameError:
        # cupy failed to import at module load time; degrade gracefully.
        return None
|
|
def reset_parameters(model):
    """Re-initialize every direct child layer of *model* that supports it.

    Returns the same model object so calls can be chained.
    """
    for child in model.children():
        reset = getattr(child, 'reset_parameters', None)
        if reset is not None:
            reset()
    return model
|
|
def channel_squeeze(x, dim=1):
    """Merge dimensions *dim* and *dim*+1 of *x* into a single dimension."""
    lead = x.shape[:dim]
    trail = x.shape[dim + 2:]
    return x.reshape(*lead, -1, *trail)
def channel_unsqueeze(x, shape, dim=1):
    """Split dimension *dim* of *x* into the dimensions given by *shape*."""
    lead = x.shape[:dim]
    trail = x.shape[dim + 1:]
    return x.reshape(*lead, *shape, *trail)
|
|
def default_collate(items, device=None):
    """Collate *items* with torch's default collate, then move tensors to *device*."""
    batch = torch.utils.data.dataloader.default_collate(items)
    return to(dict(batch), device)


def to(x, device):
    """Move *x* to *device*.

    Accepts a dict (tensor values moved, other values passed through),
    a tensor, or a numpy array (converted to a tensor first).
    device=None is a no-op; any other type is rejected.
    """
    if device is None:
        return x
    if issubclass(x.__class__, dict):
        moved = {}
        for key, val in x.items():
            moved[key] = val.to(device) if isinstance(val, torch.Tensor) else val
        return moved
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, np.ndarray):
        return torch.tensor(x).to(device)
    assert 0, 'data not understood'
|
|
| |
|
|
class SSIMMetric(torchmetrics.Metric):
    """Running-average SSIM computed per image via the project-level calc_ssim.

    NOTE(review): images are converted to PIL through T.ToPILImage before
    being handed to calc_ssim (defined outside this file, presumably in
    twodee_v0) -- confirm the expected value range with that helper.
    """

    def __init__(self, window_size=11, **kwargs):
        super().__init__(**kwargs)
        # window_size is stored but never read again in this class.
        self.window_size = window_size
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')
        # per-image counter; diagnostic only, not registered metric state
        self.idd = 0
        self.transform = T.ToPILImage()

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        batch_size = preds.size()[0]
        for idx in range(batch_size):
            pred_img = self.transform(preds[idx])
            target_img = self.transform(target[idx])
            self.idd += 1
            self.running_sum += calc_ssim(pred_img, target_img)
        self.running_count += batch_size

    def compute(self):
        return self.running_sum.float() / self.running_count
|
|
class SSIMMetricCPU(torchmetrics.Metric):
    """Running-average SSIM computed batch-wise with kornia.metrics.ssim."""

    # torchmetrics flag: update() only needs the local batch, no global state
    full_state_update = False

    def __init__(self, window_size=11, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # kornia returns a per-pixel SSIM map; reduce over C,H,W per sample
        per_sample = kornia.metrics.ssim(target, preds, self.window_size).mean((1, 2, 3))
        self.running_sum += per_sample.sum()
        self.running_count += len(per_sample)

    def compute(self):
        return self.running_sum / self.running_count
|
|
class PSNRMetric(torchmetrics.Metric):
    """Running-average PSNR over batches.

    Per sample: PSNR = 20*log10(data_range) - 10*log10(MSE).

    Args:
        data_range: value range of the inputs (default 1.0).
    """

    def __init__(self, data_range=1.0, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): plain tensor attribute, not a registered buffer, so it
        # does not follow the metric across devices -- confirm GPU usage.
        self.data_range = torch.tensor(data_range)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # -10*log10(MSE) per sample, MSE reduced over C,H,W
        ans = -10 * torch.log10((target - preds).pow(2).mean((1, 2, 3)))
        # BUGFIX: the 20*log10(data_range) term belongs to every *sample*;
        # it was previously added once per *batch*, which skews the average
        # whenever data_range != 1 and the batch has more than one sample.
        self.running_sum += len(ans) * 20 * torch.log10(self.data_range) + ans.sum()
        self.running_count += len(ans)

    def compute(self):
        return self.running_sum.float() / self.running_count
class PSNRMetricCPU(torchmetrics.Metric):
    """Running-average PSNR computed on CPU numpy arrays via skimage."""

    full_state_update = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # NOTE(review): no data_range is passed, so skimage infers it from the
        # array dtype; for float inputs that may not match the actual value
        # range -- confirm with callers.
        total = 0.0
        count = 0
        for pred, tgt in zip(preds, target):
            total += skimage.metrics.peak_signal_noise_ratio(
                pred.permute(1, 2, 0).cpu().numpy(),
                tgt.permute(1, 2, 0).cpu().numpy(),
            )
            count += 1
        self.running_sum += total
        self.running_count += count

    def compute(self):
        return self.running_sum / self.running_count
|
|
class LPIPSMetric(torchmetrics.Metric):
    """Running-average LPIPS perceptual distance between image batches."""

    full_state_update = False

    def __init__(self, net_type='alex', **kwargs):
        super().__init__(**kwargs)
        self.net_type = net_type
        assert self.net_type in ['alex', 'vgg', 'squeeze']
        self.model = lpips.LPIPS(net=self.net_type)
        self.add_state('running_sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('running_count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Skip autograd bookkeeping when the inputs carry no gradient.
        if preds.requires_grad:
            per_sample = self.model(preds, target).mean((1, 2, 3))
        else:
            with torch.no_grad():
                per_sample = self.model(preds, target).mean((1, 2, 3))
        self.running_sum += per_sample.sum()
        self.running_count += len(per_sample)

    def compute(self):
        return self.running_sum.float() / self.running_count
class LPIPSLoss(nn.Module):
    """LPIPS perceptual distance packaged as a differentiable loss module."""

    def __init__(self, net_type='alex', **kwargs):
        super().__init__()
        self.net_type = net_type
        assert self.net_type in ['alex', 'vgg', 'squeeze']
        # Extra kwargs are forwarded to lpips.LPIPS (unlike LPIPSMetric).
        self.model = lpips.LPIPS(net=self.net_type, **kwargs)

    def forward(self, preds: torch.Tensor, target: torch.Tensor):
        # Reduce the (N,1,1,1) lpips output to one distance per sample.
        return self.model(preds, target).mean((1, 2, 3))
|
|
class LaplacianPyramidLoss(nn.Module):
    """Per-sample L1/L2 loss averaged across the levels of an image pyramid.

    Args:
        n_levels: number of pyramid levels to build.
        colorspace: 'lab' converts RGB inputs to CIELAB first; any other
            value leaves the inputs untouched.
        mode: 'l1' (mean absolute difference) or 'l2' (mean channel norm).
    """

    def __init__(self, n_levels=3, colorspace=None, mode='l1'):
        super().__init__()
        self.n_levels = n_levels
        self.colorspace = colorspace
        self.mode = mode
        assert self.mode in ['l1', 'l2']

    def forward(self, preds, target, force_levels=None, force_mode=None):
        """Compute the loss; force_levels/force_mode override the defaults per call."""
        if self.colorspace == 'lab':
            preds = kornia.color.rgb_to_lab(preds.float())
            target = kornia.color.rgb_to_lab(target.float())
        levels = force_levels if force_levels is not None else self.n_levels
        pred_pyr = kornia.geometry.transform.build_pyramid(preds, levels)
        target_pyr = kornia.geometry.transform.build_pyramid(target, levels)
        mode = force_mode if force_mode is not None else self.mode
        if mode == 'l1':
            per_level = [
                (p - t).abs().mean((1, 2, 3))
                for p, t in zip(pred_pyr, target_pyr)
            ]
        elif mode == 'l2':
            per_level = [
                (p - t).norm(dim=1, keepdim=True).mean((1, 2, 3))
                for p, t in zip(pred_pyr, target_pyr)
            ]
        else:
            assert 0
        # average the per-sample losses over pyramid levels
        return torch.stack(per_level).mean(0)
|
|
def make_grid(tensor, nrow=8, padding=2):
    """
    Given a 4D mini-batch Tensor of shape (B x C x H x W),
    or a list of images all of the same size,
    makes a grid of images.

    A single 2D or 3D image is returned directly (1-channel images are
    replicated to 3 channels); no grid layout is applied in that case.
    The grid background is filled with the batch maximum.
    """
    if isinstance(tensor, list):
        # BUGFIX: the old path used Python-2-only `long(...)` and raised
        # NameError on any list input; torch.stack replaces the manual
        # new()/copy_() loop.
        tensor = torch.stack(tensor, dim=0)
    if tensor.dim() == 2:
        # single grayscale image (H, W) -> (1, H, W)
        tensor = tensor.view(1, tensor.size(0), tensor.size(1))
    if tensor.dim() == 3:
        # single image: replicate 1-channel to 3 and return as-is
        if tensor.size(0) == 1:
            tensor = torch.cat((tensor, tensor, tensor), 0)
        return tensor
    if tensor.dim() == 4 and tensor.size(1) == 1:
        # grayscale batch -> 3 channels
        tensor = torch.cat((tensor, tensor, tensor), 1)

    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(nmaps / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    grid = tensor.new(3, height * ymaps, width * xmaps).fill_(tensor.max())
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            grid.narrow(1, y * height + 1 + padding // 2, height - padding)\
                .narrow(2, x * width + 1 + padding // 2, width - padding)\
                .copy_(tensor[k])
            k = k + 1
    return grid
|
|
def save_image(tensor, filename, nrow=8, padding=2):
    """
    Saves a given Tensor into an image file.
    If given a mini-batch tensor, will save the tensor as a grid of images.

    Values are mapped from [-1, 1] to [0, 255] bytes before saving.
    """
    grid = make_grid(tensor.cpu(), nrow=nrow, padding=padding)
    # (C,H,W) in [-1,1] -> (H,W,C) uint8 for PIL
    array = grid.mul(0.5).add(0.5).mul(255).byte().permute(1, 2, 0).numpy()
    Image.fromarray(array).save(filename)
|
|
|
|
|
|