| import tqdm |
import torch


class TorchScaler(torch.nn.Module):
| """ |
    This torch module implements scaling of input tensors, using either
    instance-based or dataset-wide statistics.

    Args:
        statistic: str, (default='dataset') how to compute the statistics used for normalisation.
            Choice in {'dataset', 'instance'}.
            'dataset' requires the scaler to be fit with fit() on a DataLoader of the dataset.
            'instance' applies the normalisation at instance level, computing the statistics
            on the given input, which can be a clip or a batch.
        normtype: str, (default='standard') the type of normalisation to use.
            Choice in {'standard', 'mean', 'minmax'}. 'standard' applies a classic normalisation with mean and standard
            deviation. 'mean' subtracts the mean from the data. 'minmax' subtracts the minimum of the data and divides by
| the difference between max and min. |
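        dims: tuple, (default=(1, 2)) the dimensions over which the statistics are computed.
        eps: float, (default=1e-8) small constant added to denominators for numerical stability.

    Example:
        A minimal usage sketch; the tensor shapes and the in-memory DataLoader below are
        only illustrative assumptions:

        >>> scaler = TorchScaler(statistic="instance", normtype="minmax", dims=(1, 2))
        >>> scaled = scaler(torch.randn(4, 628, 128))

        With dataset-wide statistics the scaler must be fit on a DataLoader first:

        >>> from torch.utils.data import DataLoader, TensorDataset
        >>> dl = DataLoader(TensorDataset(torch.randn(16, 628, 128)), batch_size=4)
        >>> scaler = TorchScaler(statistic="dataset", normtype="standard", dims=(1, 2))
        >>> scaler.fit(dl)
        >>> scaled = scaler(torch.randn(4, 628, 128))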
| """ |
|
|
| def __init__(self, statistic="dataset", normtype="standard", dims=(1, 2), eps=1e-8): |
| super(TorchScaler, self).__init__() |
| assert statistic in ["dataset", "instance", None] |
| assert normtype in ["standard", "mean", "minmax", None] |
| if statistic == "dataset" and normtype == "minmax": |
| raise NotImplementedError( |
| "statistic==dataset and normtype==minmax is not currently implemented." |
| ) |
| self.statistic = statistic |
| self.normtype = normtype |
| self.dims = dims |
| self.eps = eps |
|
|
| def load_state_dict(self, state_dict, strict=True): |
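        # The mean/mean_squared buffers only exist when statistic == "dataset";
        # otherwise there is nothing to load.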
| if self.statistic == "dataset": |
| super(TorchScaler, self).load_state_dict(state_dict, strict) |
|
|
| def _load_from_state_dict( |
| self, |
| state_dict, |
| prefix, |
| local_metadata, |
| strict, |
| missing_keys, |
| unexpected_keys, |
| error_msgs, |
| ): |
| if self.statistic == "dataset": |
| super(TorchScaler, self)._load_from_state_dict( |
| state_dict, |
| prefix, |
| local_metadata, |
| strict, |
| missing_keys, |
| unexpected_keys, |
| error_msgs, |
| ) |
|
|
| def fit(self, dataloader, transform_func=lambda x: x[0]): |
| """ |
| Scaler fitting |
| |
| Args: |
| dataloader (DataLoader): training data DataLoader |
| transform_func (lambda function, optional): Transforms applied to the data. |
| Defaults to lambdax:x[0]. |
| """ |
| indx = 0 |
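        # Accumulate per-batch mean and mean of squares over self.dims; averaged over batches below.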
        for batch in tqdm.tqdm(dataloader):
            feats = transform_func(batch)
| if indx == 0: |
| mean = torch.mean(feats, self.dims, keepdim=True).mean(0).unsqueeze(0) |
| mean_squared = ( |
| torch.mean(feats ** 2, self.dims, keepdim=True).mean(0).unsqueeze(0) |
| ) |
| else: |
| mean += torch.mean(feats, self.dims, keepdim=True).mean(0).unsqueeze(0) |
| mean_squared += ( |
| torch.mean(feats ** 2, self.dims, keepdim=True).mean(0).unsqueeze(0) |
| ) |
| indx += 1 |
|
|
| mean /= indx |
| mean_squared /= indx |
|
|
| self.register_buffer("mean", mean) |
| self.register_buffer("mean_squared", mean_squared) |
|
|
    def forward(self, tensor):
        if self.statistic is None or self.normtype is None:
            return tensor
|
|
| if self.statistic == "dataset": |
            assert hasattr(self, "mean") and hasattr(
                self, "mean_squared"
            ), "TorchScaler should be fit before it is used when statistic='dataset'."
            assert (
                tensor.ndim == self.mean.ndim
            ), "Pre-computed statistics do not match the number of dimensions of the input tensor."
| if self.normtype == "mean": |
| return tensor - self.mean |
| elif self.normtype == "standard": |
| std = torch.sqrt(self.mean_squared - self.mean ** 2) |
| return (tensor - self.mean) / (std + self.eps) |
| else: |
| raise NotImplementedError |
|
|
| else: |
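            # Instance-level statistics are computed on the fly from the input tensor itself.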
| if self.normtype == "mean": |
| return tensor - torch.mean(tensor, self.dims, keepdim=True) |
| elif self.normtype == "standard": |
| return (tensor - torch.mean(tensor, self.dims, keepdim=True)) / ( |
| torch.std(tensor, self.dims, keepdim=True) + self.eps |
| ) |
| elif self.normtype == "minmax": |
| return (tensor - torch.amin(tensor, dim=self.dims, keepdim=True)) / ( |
| torch.amax(tensor, dim=self.dims, keepdim=True) |
| - torch.amin(tensor, dim=self.dims, keepdim=True) |
| + self.eps |
| ) |
|
|