| import numpy as np |
| import pickle |
| from torch.utils.data import Dataset, DataLoader |
| import os |
| import torch |
| from copy import deepcopy |
| from blimpy import Waterfall |
| from tqdm import tqdm |
| from copy import deepcopy |
| from sigpyproc.readers import FilReader |
| from torch import nn |
|
|
|
|
def renorm_batched(data):
    """Standardize each sample of a batch to zero mean and unit standard deviation.

    Statistics are computed independently for every index along axis 0, over
    all remaining axes, using ``torch.std``'s default (unbiased) estimator.

    Samples whose standard deviation is zero or undefined (e.g. a constant
    sample, or a sample holding a single element) are only mean-centred
    rather than divided by zero, which would otherwise fill the whole sample
    with NaN/inf.

    Args:
        data: Floating-point tensor whose first axis indexes samples.

    Returns:
        Tensor with the same shape as ``data``.
    """
    reduce_dims = tuple(range(1, data.ndim))
    mean = torch.mean(data, dim=reduce_dims, keepdim=True)
    std = torch.std(data, dim=reduce_dims, keepdim=True)
    # ``std > 0`` is False for both 0 and NaN, so degenerate samples are
    # divided by 1 instead of producing NaN/inf for every element.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))
    return (data - mean) / safe_std
|
|
def transform_batched(data, masks_rms=(-1, 5), n_chunks=8):
    """Build a multi-channel, log-scaled, per-sample standardized view of a batch.

    ``1 + len(masks_rms)`` variants of ``data`` are produced, each passed
    through ``log10(x + 1e-10)`` and then ``renorm_batched``:

    * variant 0: the unmasked data;
    * one variant per ``scale`` in ``masks_rms``: a copy of ``data`` with some
      values set to 0. A negative ``scale`` zeroes values below
      ``abs(scale) * rms + mean``; a non-negative ``scale`` zeroes values
      above ``scale * rms + mean``. Here ``mean``/``rms`` are the per-sample
      mean and standard deviation of ``data`` over all non-batch axes.
      NOTE(review): the negative branch thresholds at ``mean + |scale|*rms``,
      not ``mean - |scale|*rms`` -- confirm this is intended.

    The last axis of every variant is then cut into ``n_chunks`` equal pieces,
    and the pieces are folded into the channel axis.

    The result does not track gradients: the input is detached first.

    Args:
        data: Tensor whose first axis indexes samples. The final reshape only
            works for 3-D input; presumably ``(batch, rows, cols)`` -- TODO
            confirm axis meaning with callers.
        masks_rms: Thresholds, in units of the per-sample std, for the masked
            variants. The default reproduces the original hard-coded ``[-1, 5]``.
        n_chunks: Number of equal pieces the last axis is split into
            (default 8, as originally hard-coded).

    Returns:
        float32 tensor of shape
        ``(batch, (1 + len(masks_rms)) * n_chunks, rows, cols // n_chunks)``.
        With the defaults this is 24 channels.

    Raises:
        ValueError: if ``n_chunks`` is not positive, or the last dimension of
            ``data`` is not divisible by ``n_chunks``.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be a positive integer, got {n_chunks}")
    if data.shape[-1] % n_chunks != 0:
        # torch.chunk would return unequal or too few pieces here, and the
        # stack/reshape below would fail with an obscure message.
        raise ValueError(
            f"last dimension ({data.shape[-1]}) must be divisible by "
            f"n_chunks ({n_chunks})"
        )

    base = data.detach()
    reduce_dims = tuple(range(1, base.ndim))
    rms = torch.std(base, dim=reduce_dims, keepdim=True)
    mean = torch.mean(base, dim=reduce_dims, keepdim=True)

    # The 1e-10 offset keeps log10 finite where values are exactly zero.
    variants = [renorm_batched(torch.log10(base + 1e-10))]
    for scale in masks_rms:
        if scale < 0:
            mask = base < abs(scale) * rms + mean
        else:
            mask = base > scale * rms + mean
        # masked_fill is out-of-place, so ``base`` stays untouched for the next
        # variant. Zeroed entries become log10(1e-10) == -10 before renorm.
        masked = base.masked_fill(mask, 0)
        variants.append(renorm_batched(torch.log10(masked + 1e-10)))

    num_variants = len(variants)
    stacked = torch.stack(variants).to(torch.float32)  # (K, *data.shape)

    # Split the last axis into n_chunks pieces and fold them into channels:
    # (K, B, R, C) -> (K, B, n, R, C/n) -> (B, K, n, R, C/n) -> (B, K*n, R, C/n)
    pieces = torch.chunk(stacked, n_chunks, dim=-1)
    out = torch.stack(pieces, dim=2)
    out = torch.swapaxes(out, 0, 1)
    return out.reshape(
        out.size(0), num_variants * n_chunks, out.size(3), out.size(4)
    )
|
|
class preproc_flip(nn.Module):
    """Module that reverses the input along its second-to-last axis, then applies ``transform_batched``."""

    def forward(self, x, flip=True):
        """Flip ``x`` along axis -2 and return ``transform_batched`` of the result.

        Args:
            x: Input tensor forwarded to ``transform_batched``.
            flip: Accepted for call compatibility but never consulted; the
                flip is always applied.
        """
        reversed_x = torch.flip(x, dims=[-2])
        output = transform_batched(reversed_x)
        return output
|
|
class preproc(nn.Module):
    """Module wrapper that applies ``transform_batched`` to its input unchanged."""

    def forward(self, x, flip=True):
        """Return ``transform_batched(x)``.

        Args:
            x: Input tensor forwarded to ``transform_batched``.
            flip: Accepted for call compatibility but never consulted.
        """
        output = transform_batched(x)
        return output
| |