| """ |
| @inproceedings{guo2022online, |
| title={Online continual learning through mutual information maximization}, |
| author={Guo, Yiduo and Liu, Bing and Zhao, Dongyan}, |
| booktitle={International Conference on Machine Learning}, |
| pages={8109--8126}, |
| year={2022}, |
| organization={PMLR} |
| } |
| https://proceedings.mlr.press/v162/guo22g.html |
| |
| Code Reference: |
| https://github.com/gydpku/OCM/blob/main/test_cifar10.py |
| |
| We referred to the original author's code implementation and performed structural refactoring. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from copy import deepcopy |
| from core.model.buffer.onlinebuffer import OnlineBuffer |
| import math |
| import numbers |
| import numpy as np |
| from torch.autograd import Function |
| import torch.distributed as dist |
| import diffdist.functional as distops |
| from torchvision import transforms |
|
|
# Extra keyword arguments forwarded to F.affine_grid / F.grid_sample by the
# augmentation layers in this file: `align_corners=False` is passed explicitly
# when the installed torch reports a version >= '1.4.0', nothing otherwise.
kwargs = {'align_corners': False} if torch.__version__ >= '1.4.0' else {}
|
|
| |
| import math |
| import numbers |
| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.autograd import Function |
|
|
# Keyword arguments handed to F.affine_grid / F.grid_sample by the layers
# below; `align_corners=False` is only supplied when torch reports a version
# >= '1.4.0'.
kwargs = {'align_corners': False} if torch.__version__ >= '1.4.0' else {}
|
|
|
|
def rgb2hsv(rgb):
    """Convert a 4-d RGB tensor (channels on dim 1) to HSV.

    Hue is obtained from ``atan2`` following the definition in [1] rather
    than the piecewise lookup used in [2, 3]; the two agree at multiples of
    30 degrees and may otherwise differ by at most ~1.2 degrees.

    Args:
        rgb: tensor indexed as ``rgb[:, c, :, :]`` with ``c`` in {0, 1, 2}.

    Returns:
        Tensor with hue (scaled to [0, 1)), saturation and value stacked
        along dim 1. Non-finite entries (e.g. 0/0 saturation of black
        pixels) are replaced by 0.

    References
        [1] https://en.wikipedia.org/wiki/Hue
        [2] https://www.rapidtables.com/convert/color/rgb-to-hsv.html
        [3] https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L212
    """
    red, green, blue = (rgb[:, idx, :, :] for idx in range(3))

    channel_max = rgb.max(1)[0]
    channel_min = rgb.min(1)[0]
    chroma = channel_max - channel_min

    full_turn = 2 * math.pi
    angle = torch.atan2(math.sqrt(3) * (green - blue), 2 * red - green - blue)
    hue = (angle % full_turn) / full_turn

    hsv = torch.stack([hue, chroma / channel_max, channel_max], dim=1)
    # Division by a zero maximum yields NaN/inf; map those to 0.
    hsv[~torch.isfinite(hsv)] = 0.
    return hsv
|
|
|
|
def hsv2rgb(hsv):
    """Convert a 4-d HSV tensor (channels on dim 1) back to RGB.

    Uses the closed-form ``f(n) = V - V*S*max(0, min(k, 4 - k, 1))`` with
    ``k = (n + 6H) mod 6`` and ``n`` in (5, 3, 1) for the R, G, B channels.

    Args:
        hsv: tensor whose dim 1 holds hue (in turns), saturation and value.

    Returns:
        Tensor with R, G, B on dim 1.

    References
        [1] https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB_alternative
    """
    # List-indexing keeps the channel dimension so the terms broadcast below.
    hue, sat, val = (hsv[:, [idx]] for idx in range(3))
    chroma = val * sat

    channel_offsets = hsv.new_tensor([5, 3, 1]).view(3, 1, 1)
    sector = (channel_offsets + hue * 6) % 6
    weight = torch.clamp(torch.min(sector, 4 - sector), 0, 1)

    return val - chroma * weight
|
|
|
|
class RandomResizedCropLayer(nn.Module):
    """Batched random-resized ("Inception") crop done with an affine sampling grid."""

    def __init__(self, size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
        """
        Args:
            size: stored on the module; no method of this class reads it.
            scale (tuple): (min, max) fraction of the source area a crop may cover.
            ratio (tuple): (min, max) aspect ratio a crop may take.
        """
        super(RandomResizedCropLayer, self).__init__()

        self.size = size
        self.scale = scale
        self.ratio = ratio
        # Identity 2x3 affine matrix, stored as a buffer so it follows the module's device.
        self.register_buffer('_eye', torch.eye(2, 3))

    def forward(self, inputs, whbias=None):
        """Crop every image in ``inputs`` and resample it to the input size.

        Args:
            inputs: batch of images; its first dimension is the batch size.
            whbias: optional (N, 4) tensor of (w, h, w_bias, h_bias) affine
                parameters; sampled with ``_sample_latent`` when omitted.

        Returns:
            Tensor produced by ``F.grid_sample`` on the affine grid.
        """
        batch = inputs.size(0)
        theta = self._eye.repeat(batch, 1, 1)

        if whbias is None:
            whbias = self._sample_latent(inputs)

        # Diagonal entries scale the sampling window, last column shifts it.
        theta[:, 0, 0] = whbias[:, 0]
        theta[:, 1, 1] = whbias[:, 1]
        theta[:, 0, 2] = whbias[:, 2]
        theta[:, 1, 2] = whbias[:, 3]

        grid = F.affine_grid(theta, inputs.size(), **kwargs).to(inputs.device)
        return F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)

    def _clamp(self, whbias):
        """Project (w, h, w_bias, h_bias) rows onto the allowed scale/ratio/offset ranges."""
        ratio_lo, ratio_hi = self.ratio[0], self.ratio[1]
        w, h, w_bias, h_bias = (whbias[:, col] for col in range(4))

        # Keep both extents inside the configured scale range.
        w = torch.clamp(w, *self.scale)
        h = torch.clamp(h, *self.scale)

        # Force ratio_lo * h <= w <= ratio_hi * h.
        w = ratio_lo * h + torch.relu(w - ratio_lo * h)
        w = ratio_hi * h - torch.relu(ratio_hi * h - w)

        # Force w - 1 <= w_bias <= 1 - w, and likewise for h.
        w_bias = w - 1 + torch.relu(w_bias - w + 1)
        w_bias = 1 - w - torch.relu(1 - w - w_bias)

        h_bias = h - 1 + torch.relu(h_bias - h + 1)
        h_bias = 1 - h - torch.relu(1 - h - h_bias)

        return torch.stack([w, h, w_bias, h_bias], dim=0).t()

    def _sample_latent(self, inputs):
        """Draw one (w, h, w_bias, h_bias) row per image with NumPy's global RNG.

        Ten candidates per image are drawn; candidates that do not fit inside
        the image are discarded and any shortfall is filled with full-size crops.
        """
        # NOTE(review): dim 2 is unpacked as `width` and dim 3 as `height`;
        # confirm this is intended for non-square inputs.
        batch, _, width, height = inputs.shape
        n_candidates = batch * 10

        # Candidate crop areas and aspect ratios (log-uniform in ratio).
        target_area = np.random.uniform(*self.scale, n_candidates) * (width * height)
        log_lo, log_hi = math.log(self.ratio[0]), math.log(self.ratio[1])
        aspect_ratio = np.exp(np.random.uniform(log_lo, log_hi, n_candidates))

        # Turn them into pixel extents and keep those that fit.
        w = np.round(np.sqrt(target_area * aspect_ratio))
        h = np.round(np.sqrt(target_area / aspect_ratio))
        fits = (0 < w) * (w <= width) * (0 < h) * (h <= height)
        w, h = w[fits], h[fits]

        missing = batch - w.shape[0]
        if missing <= 0:
            w, h = w[:batch], h[:batch]
        else:
            w = np.concatenate([w, np.ones(missing) * width])
            h = np.concatenate([h, np.ones(missing) * height])

        w_bias = np.random.randint(w - width, width - w + 1) / width
        h_bias = np.random.randint(h - height, height - h + 1) / height

        whbias = np.column_stack([w / width, h / height, w_bias, h_bias])
        return torch.tensor(whbias, device=inputs.device)
|
|
|
|
class HorizontalFlipRandomCrop(nn.Module):
    """Random horizontal flip plus random translation, applied through an affine grid."""

    def __init__(self, max_range):
        """
        Args:
            max_range: translations are drawn uniformly from [-max_range, max_range].
        """
        super(HorizontalFlipRandomCrop, self).__init__()
        self.max_range = max_range
        self.register_buffer('_eye', torch.eye(2, 3))

    def forward(self, input, sign=None, bias=None, rotation=None):
        """Resample ``input`` under a per-sample affine transform.

        Args:
            input: batch of images; its first dimension is the batch size.
            sign: optional per-sample value written to the x-scale entry
                (sampled from {-1, +1} when omitted).
            bias: optional (N, 2) translation (sampled uniformly when omitted).
            rotation: optional value written over the 2x2 linear part.

        Returns:
            Tensor produced by ``F.grid_sample``.
        """
        device = input.device
        batch = input.size(0)
        theta = self._eye.repeat(batch, 1, 1)

        if sign is None:
            sign = torch.bernoulli(torch.ones(batch, device=device) * 0.5) * 2 - 1
        if bias is None:
            bias = torch.empty((batch, 2), device=device).uniform_(-self.max_range, self.max_range)

        theta[:, 0, 0] = sign
        theta[:, :, 2] = bias
        if rotation is not None:
            # An explicit rotation overrides the flip entry set just above.
            theta[:, 0:2, 0:2] = rotation

        grid = F.affine_grid(theta, input.size(), **kwargs).to(device)
        return F.grid_sample(input, grid, padding_mode='reflection', **kwargs)

    def _sample_latent(self, N, device=None):
        """Return ``(sign, bias)``: N values in {-1, +1} and an (N, 2) uniform translation."""
        half = torch.ones(N, device=device) * 0.5
        sign = torch.bernoulli(half) * 2 - 1
        bias = torch.empty((N, 2), device=device).uniform_(-self.max_range, self.max_range)
        return sign, bias
|
|
|
|
class Rotation(nn.Module):
    """Rotate a batch by multiples of 90 degrees in the (2, 3) plane.

    NOTE: a module-level function also named ``Rotation`` is defined later in
    this file and rebinds this name at import time.
    """

    def __init__(self, max_range = 4):
        super(Rotation, self).__init__()
        self.max_range = max_range
        self.prob = 0.5

    def forward(self, input, aug_index=None):
        """Rotate ``input``.

        Args:
            input: 4-d tensor.
            aug_index: number of quarter turns (taken modulo ``max_range``).
                When omitted, one random number of turns is drawn for the
                batch and each sample keeps its original with probability
                ``self.prob``.

        Returns:
            Tensor with the same number of samples as ``input``.
        """
        # Unpacking also enforces that the input is exactly 4-d.
        _, _, H, W = input.size()

        if aug_index is not None:
            return torch.rot90(input, aug_index % self.max_range, (2, 3))

        quarter_turns = np.random.randint(4)
        rotated = torch.rot90(input, quarter_turns, (2, 3))

        keep_prob = input.new_full((input.size(0),), self.prob)
        keep = torch.bernoulli(keep_prob).view(-1, 1, 1, 1)
        return keep * input + (1 - keep) * rotated
|
|
|
|
class CutPerm(nn.Module):
    """Swap image halves vertically and/or horizontally ("cut and permute")."""

    def __init__(self, max_range = 4):
        super(CutPerm, self).__init__()
        self.max_range = max_range
        self.prob = 0.5

    def forward(self, input, aug_index=None):
        """Permute the halves of every image in ``input``.

        Args:
            input: 4-d tensor.
            aug_index: permutation id (taken modulo ``max_range``). When
                omitted, one random id is drawn for the batch and each sample
                keeps its original with probability ``self.prob``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Unpacking also enforces that the input is exactly 4-d.
        _, _, H, W = input.size()

        if aug_index is not None:
            return self._cutperm(input, aug_index % self.max_range)

        permutation_id = np.random.randint(4)
        permuted = self._cutperm(input, permutation_id)

        keep_prob = input.new_full((input.size(0),), self.prob)
        keep = torch.bernoulli(keep_prob).view(-1, 1, 1, 1)
        return keep * input + (1 - keep) * permuted

    def _cutperm(self, inputs, aug_index):
        """Apply permutation ``aug_index``: ``aug_index // 2`` selects the
        top/bottom swap and ``aug_index % 2`` the left/right swap."""
        _, _, H, W = inputs.size()
        h_mid = int(H / 2)
        w_mid = int(W / 2)

        swap_rows = aug_index // 2
        swap_cols = aug_index % 2

        if swap_rows == 1:
            top, bottom = inputs[:, :, 0:h_mid, :], inputs[:, :, h_mid:, :]
            inputs = torch.cat((bottom, top), dim=2)
        if swap_cols == 1:
            left, right = inputs[:, :, :, 0:w_mid], inputs[:, :, :, w_mid:]
            inputs = torch.cat((right, left), dim=3)

        return inputs
|
|
|
|
class HorizontalFlipLayer(nn.Module):
    """Flip each image of a batch horizontally with probability 0.5."""

    def __init__(self):
        """Register the identity 2x3 affine matrix used as the base transform."""
        super(HorizontalFlipLayer, self).__init__()
        self.register_buffer('_eye', torch.eye(2, 3))

    def forward(self, inputs):
        """Resample ``inputs`` with a per-sample x-scale of +1 or -1.

        Args:
            inputs: batch of images; its first dimension is the batch size.

        Returns:
            Tensor produced by ``F.grid_sample``.
        """
        device = inputs.device
        batch = inputs.size(0)

        theta = self._eye.repeat(batch, 1, 1)
        # Bernoulli(0.5) mapped from {0, 1} to {-1, +1}; -1 mirrors the x axis.
        flip_sign = torch.bernoulli(torch.ones(batch, device=device) * 0.5) * 2 - 1
        theta[:, 0, 0] = flip_sign

        grid = F.affine_grid(theta, inputs.size(), **kwargs).to(device)
        return F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs)
|
|
|
|
class RandomColorGrayLayer(nn.Module):
    """Randomly replace RGB images by their three-channel grayscale version."""

    def __init__(self, p):
        """
        Args:
            p: per-sample probability of converting to grayscale when no
                ``aug_index`` is given to ``forward``.
        """
        super(RandomColorGrayLayer, self).__init__()
        self.prob = p

        # Luma weights for the R, G, B channels, shaped as a 1x1 conv kernel.
        luma = torch.tensor([[0.299, 0.587, 0.114]])
        self.register_buffer('_weight', luma.view(1, 3, 1, 1))

    def forward(self, inputs, aug_index=None):
        """Apply the grayscale augmentation.

        Args:
            inputs: batch of 3-channel images.
            aug_index: ``0`` returns ``inputs`` untouched; ``None`` converts
                each sample with probability ``self.prob``; any other value
                converts every sample.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        if aug_index == 0:
            return inputs

        luminance = F.conv2d(inputs, self._weight)
        gray = torch.cat([luminance, luminance, luminance], dim=1)

        if aug_index is not None:
            return gray

        convert_prob = inputs.new_full((inputs.size(0),), self.prob)
        convert = torch.bernoulli(convert_prob).view(-1, 1, 1, 1)
        return inputs * (1 - convert) + gray * convert
|
|
|
|
class ColorJitterLayer(nn.Module):
    """Batched colour jitter: random contrast plus random hue/saturation/brightness."""

    def __init__(self, p, brightness, contrast, saturation, hue):
        """
        Args:
            p: per-sample probability of applying the jitter.
            brightness, contrast, saturation: a non-negative number ``v``
                (range ``[max(1 - v, 0), 1 + v]``) or an explicit (min, max) pair.
            hue: a non-negative number ``v`` (range ``[-v, v]``) or an
                explicit (min, max) pair within [-0.5, 0.5].
        """
        super(ColorJitterLayer, self).__init__()
        self.prob = p
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Normalise a jitter setting to a (min, max) range.

        Returns:
            The range, or ``None`` when it collapses to ``center`` (no-op).

        Raises:
            ValueError: negative scalar, or a pair outside ``bound`` / not ordered.
            TypeError: ``value`` is neither a number nor a 2-element list/tuple.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            lower = center - value
            if clip_first_on_zero:
                lower = max(lower, 0)
            value = [lower, center + value]
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))

        # A range that is exactly [center, center] would change nothing.
        if value[0] == value[1] == center:
            return None
        return value

    def adjust_contrast(self, x):
        """Scale each sample around its per-channel spatial mean, then clamp to [0, 1]."""
        if self.contrast:
            scale = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
            channel_mean = torch.mean(x, dim=[2, 3], keepdim=True)
            x = (x - channel_mean) * scale + channel_mean
        return torch.clamp(x, 0, 1)

    def adjust_hsv(self, x):
        """Draw per-sample hue/saturation/value factors and apply them via ``RandomHSVFunction``."""
        batch = x.size(0)
        hue_shift = x.new_zeros(batch, 1, 1)
        sat_scale = x.new_ones(batch, 1, 1)
        val_scale = x.new_ones(batch, 1, 1)

        # Settings left as None keep their neutral factor (0 shift / 1 scale).
        if self.hue:
            hue_shift.uniform_(*self.hue)
        if self.saturation:
            sat_scale = sat_scale.uniform_(*self.saturation)
        if self.brightness:
            val_scale = val_scale.uniform_(*self.brightness)

        return RandomHSVFunction.apply(x, hue_shift, sat_scale, val_scale)

    def transform(self, inputs):
        """Apply contrast and HSV jitter in a random order."""
        if np.random.rand() > 0.5:
            pipeline = (self.adjust_contrast, self.adjust_hsv)
        else:
            pipeline = (self.adjust_hsv, self.adjust_contrast)

        for step in pipeline:
            inputs = step(inputs)
        return inputs

    def forward(self, inputs):
        """Jitter each sample with probability ``self.prob``; others pass through."""
        apply_prob = inputs.new_full((inputs.size(0),), self.prob)
        apply = torch.bernoulli(apply_prob).view(-1, 1, 1, 1)
        return inputs * (1 - apply) + self.transform(inputs) * apply
|
|
|
|
class RandomHSVFunction(Function):
    """Autograd function applying hue/saturation/value factors in HSV space.

    The backward pass hands the incoming gradient straight through to ``x``.
    """

    @staticmethod
    def forward(ctx, x, f_h, f_s, f_v):
        """Shift hue by ``f_h * 255 / 360`` (mod 1), scale saturation by
        ``f_s`` and value by ``f_v``, clamp to [0, 1] and convert back to RGB."""
        hsv = rgb2hsv(x)

        hue = hsv[:, 0, :, :]
        hue += (f_h * 255. / 360.)
        hsv[:, 0, :, :] = hue % 1
        hsv[:, 1, :, :] = hsv[:, 1, :, :] * f_s
        hsv[:, 2, :, :] = hsv[:, 2, :, :] * f_v

        return hsv2rgb(torch.clamp(hsv, 0, 1))

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` for ``x`` (when it needs a
        gradient) and ``None`` for the three factor tensors."""
        if ctx.needs_input_grad[0]:
            return grad_output.clone(), None, None, None
        return None, None, None, None
|
|
|
|
class NormalizeLayer(nn.Module):
    """Affine rescaling layer computing ``(inputs - 0.5) / 0.5``.

    Values in [0, 1] are therefore mapped to [-1, 1].
    """

    def __init__(self):
        super(NormalizeLayer, self).__init__()

    def forward(self, inputs):
        """Return ``inputs`` shifted by 0.5 and divided by 0.5."""
        shift = scale = 0.5
        return (inputs - shift) / scale
|
|
| import torch |
| from torch import Tensor |
| from torchvision.transforms.functional import to_pil_image, to_tensor |
| from torch.nn.functional import conv2d, pad as torch_pad |
| from typing import Any, List, Sequence, Optional |
| import numbers |
| import numpy as np |
| import torch |
| from PIL import Image |
| from typing import Tuple |
|
|
class GaussianBlur(torch.nn.Module):
    """Blur an image with a Gaussian kernel whose sigma is drawn at random.

    The image may be a PIL Image or a Tensor; a Tensor is expected to have
    [..., C, H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel; every
            entry must be odd and positive.
        sigma (float or tuple of float (min, max)): Standard deviation used
            to build the kernel. A single float fixes sigma; a (min, max)
            tuple makes each call draw sigma uniformly from that range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.
    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        if any(ks <= 0 or ks % 2 == 0 for ks in self.kernel_size):
            raise ValueError("Kernel size value should be an odd and positive number.")

        is_scalar = isinstance(sigma, numbers.Number)
        is_pair = (not is_scalar) and isinstance(sigma, Sequence) and len(sigma) == 2
        if not (is_scalar or is_pair):
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        if is_scalar:
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif not 0. < sigma[0] <= sigma[1]:
            raise ValueError("sigma values should be positive and of the form (min, max).")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Draw the sigma used for one blur.

        Args:
            sigma_min (float): Smallest standard deviation that may be drawn.
            sigma_max (float): Largest standard deviation that may be drawn.

        Returns:
            float: Standard deviation sampled uniformly from the given range.
        """
        sample = torch.empty(1).uniform_(sigma_min, sigma_max)
        return sample.item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        sigma_lo, sigma_hi = self.sigma[0], self.sigma[1]
        sigma = self.get_params(sigma_lo, sigma_hi)
        # The same sigma is used along both axes.
        return gaussian_blur(img, self.kernel_size, [sigma, sigma])

    def __repr__(self):
        return '{}(kernel_size={}, sigma={})'.format(
            self.__class__.__name__, self.kernel_size, self.sigma)
|
|
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    """Return True when ``img`` is an instance of ``PIL.Image.Image``."""
    pil_image_type = Image.Image
    return isinstance(img, pil_image_type)
def _setup_size(size, error_msg):
    """Expand ``size`` into a pair.

    Args:
        size: a number (used for both entries, cast to int), a one-element
            sequence (its element is duplicated) or a two-element object
            (returned unchanged).
        error_msg: message of the ``ValueError`` raised for any other length.

    Returns:
        A two-element tuple, or ``size`` itself when it already has length 2.
    """
    if isinstance(size, numbers.Number):
        side = int(size)
        return side, side

    if isinstance(size, Sequence) and len(size) == 1:
        only = size[0]
        return only, only

    if len(size) == 2:
        return size

    raise ValueError(error_msg)
def _is_tensor_a_torch_image(x: Tensor) -> bool:
    """Return True when ``x`` has at least two dimensions."""
    min_image_dims = 2
    return x.ndim >= min_image_dims
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    """Build a 1-d Gaussian kernel of ``kernel_size`` taps that sums to 1.

    The taps are evaluated at ``kernel_size`` equally spaced positions
    centred on zero.
    """
    half_width = (kernel_size - 1) * 0.5
    positions = torch.linspace(-half_width, half_width, steps=kernel_size)

    weights = torch.exp(-0.5 * (positions / sigma).pow(2))
    return weights / weights.sum()
|
|
def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]:
    """Prepare ``img`` for a batched op: add a leading dim and cast if needed.

    Args:
        img: input tensor.
        req_dtype: dtype required by the caller.

    Returns:
        ``(img, need_cast, need_squeeze, out_dtype)`` where ``need_squeeze``
        tells whether a leading dimension was added (input had fewer than 4
        dims), ``need_cast`` whether the dtype was changed, and ``out_dtype``
        is the original dtype.
    """
    need_squeeze = img.ndim < 4
    if need_squeeze:
        # Give the tensor a leading batch dimension.
        img = img.unsqueeze(dim=0)

    out_dtype = img.dtype
    need_cast = out_dtype != req_dtype
    if need_cast:
        img = img.to(req_dtype)

    return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype):
    """Undo ``_cast_squeeze_in``: drop the added leading dim and restore the dtype.

    Args:
        img: tensor produced by the batched op.
        need_cast: whether ``img`` must be cast back to ``out_dtype``.
        need_squeeze: whether the leading dimension must be removed.
        out_dtype: dtype the result should have when ``need_cast`` is set.

    Returns:
        The restored tensor.
    """
    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        # Round only when going back to an integer dtype: rounding before a
        # cast to another floating dtype would throw away the fractional part.
        if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            img = torch.round(img)
        img = img.to(out_dtype)

    return img
def _get_gaussian_kernel2d(
        kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Build a 2-d Gaussian kernel as the outer product of two 1-d kernels.

    Entry 0 of ``kernel_size`` / ``sigma`` describes the x (column) direction
    and entry 1 the y (row) direction.
    """
    row_profile = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    col_profile = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    return torch.mm(col_profile[:, None], row_profile[None, :])
def _gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """PRIVATE METHOD. Performs Gaussian blurring on the img by given kernel.

    Args:
        img (Tensor): Image to be blurred.
        kernel_size (sequence of int): Kernel size of the Gaussian kernel ``(kx, ky)``.
        sigma (sequence of float): Standard deviation of the Gaussian kernel ``(sx, sy)``.

    Returns:
        Tensor: An image that is blurred using gaussian kernel of given parameters.

    Raises:
        TypeError: ``img`` is not a tensor, or is a tensor with fewer than
            two dimensions.
    """
    # Both conditions must hold. With `or`, a non-tensor reached `.ndim`
    # inside the helper and a 1-d tensor slipped through the guard.
    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

    # Integer images are blurred in float32 and cast back afterwards.
    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    # One copy of the kernel per channel for the grouped (depthwise) convolution.
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, kernel.dtype)

    # Reflect-pad by half the kernel (left, right, top, bottom) so the output
    # keeps the input's spatial size.
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
|
|
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
    """Blur ``img`` with a Gaussian kernel.

    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions.

    Args:
        img (PIL Image or Tensor): Image to be blurred.
        kernel_size (sequence of ints or int): Gaussian kernel size, either
            ``(kx, ky)`` or a single integer for a square kernel. Entries
            must be odd.
        sigma (sequence of floats or float, optional): Gaussian kernel
            standard deviation, either ``(sigma_x, sigma_y)``, a one-element
            sequence or a single number used for both directions. If None,
            it is computed from ``kernel_size`` as
            ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.

    Returns:
        PIL Image or Tensor: Gaussian Blurred version of the image.

    Raises:
        TypeError: wrong type for ``kernel_size``, ``sigma`` or ``img``.
        ValueError: wrong length or non-admissible values in ``kernel_size``
            or ``sigma``.
    """
    # ---- kernel size validation -------------------------------------------
    if not isinstance(kernel_size, (int, list, tuple)):
        raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
    if any(ksize % 2 == 0 or ksize < 0 for ksize in kernel_size):
        raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))

    # ---- sigma defaulting and validation ----------------------------------
    if sigma is None:
        # Algebraically the same as the formula quoted in the docstring.
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    elif len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
    if any(s <= 0. for s in sigma):
        raise ValueError('sigma should have positive values. Got {}'.format(sigma))

    # ---- blur, converting PIL images to tensors and back -------------------
    input_is_tensor = isinstance(img, torch.Tensor)
    if input_is_tensor:
        t_img = img
    else:
        if not _is_pil_image(img):
            raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))
        t_img = to_tensor(img)

    output = _gaussian_blur(t_img, kernel_size, sigma)

    if not input_is_tensor:
        output = to_pil_image(output)
    return output
|
|
| |
|
|
|
|
|
|
|
|
def normalize(x, dim=1, eps=1e-8):
    """Divide ``x`` by its norm along ``dim``; ``eps`` is added to the norm
    to avoid division by zero."""
    magnitude = x.norm(dim=dim, keepdim=True)
    return x / (magnitude + eps)
|
|
|
|
def rot_inner_all(x):
    """Build four "inner rotation" variants of every image in ``x``.

    Each image is split into two halves along its last dimension and either
    half may be rotated by 180 degrees in place. The result stacks, along
    dim 0: the originals, first half rotated, both halves rotated, and second
    half rotated.

    Args:
        x: tensor reshaped here as (N, 3, S, S) — three channels and a square,
            even-sized image are assumed by the ``view`` / ``reshape`` calls.

    Returns:
        Tensor with ``4 * N`` samples and the per-sample shape of ``x``.
    """
    batch = x.shape[0]
    side = x.shape[2]

    # Block 0 already holds the untouched images; blocks 1..3 are overwritten.
    stacked = x.repeat(4, 1, 1, 1)

    # Swap the spatial axes, then split the (now leading) spatial axis in two.
    halves = x.permute(0, 1, 3, 2).view(batch, 3, 2, side // 2, side).permute(2, 0, 1, 3, 4)
    first, second = halves[0], halves[1]
    first_rot = torch.rot90(first, 2, (2, 3))
    second_rot = torch.rot90(second, 2, (2, 3))

    def merge(part_a, part_b):
        # Re-join two halves and undo the initial axis swap.
        joined = torch.cat((part_a.unsqueeze(2), part_b.unsqueeze(2)), dim=2)
        return joined.reshape(batch, 3, side, side).permute(0, 1, 3, 2)

    stacked[batch: 2 * batch] = merge(first_rot, second)
    stacked[3 * batch:] = merge(first, second_rot)
    stacked[2 * batch: 3 * batch] = merge(first_rot, second_rot)
    return stacked
|
|
|
|
def Rotation(x, y):
    """Expand a batch into 16 rotated views with view-specific labels.

    The four inner-rotation variants from ``rot_inner_all`` are each rotated
    by 0, 90, 180 and 270 degrees. Labels are tiled 16 times and the i-th
    copy is offset by ``1000 * i`` so every view gets its own label.

    NOTE: this function has the same name as the ``Rotation`` nn.Module
    defined earlier in this file and rebinds that name at module level.

    Returns:
        ``(views, labels)`` with ``16 * N`` entries each.
    """
    batch = x.shape[0]
    inner = rot_inner_all(x)

    labels = y.repeat(16)
    for view_idx in range(1, 16):
        start = view_idx * batch
        labels[start:start + batch] += 1000 * view_idx

    views = torch.cat(
        (inner,
         torch.rot90(inner, 1, (2, 3)),
         torch.rot90(inner, 2, (2, 3)),
         torch.rot90(inner, 3, (2, 3))),
        dim=0)
    return views, labels
|
|
|
|
|
|
|
|
|
|
def get_similarity_matrix(outputs, chunk=2, multi_gpu=False):
    '''
    Compute the matrix of pairwise inner products of the rows of ``outputs``.

    - outputs: (B', d) tensor for B' = B * chunk
    - returns: (B', B') tensor ``outputs @ outputs.T``

    With ``multi_gpu`` set, each of the ``chunk`` pieces of ``outputs`` is
    first all-gathered across the distributed workers and the gathered pieces
    are concatenated before the product is taken.

    Code Reference:
    https://github.com/gydpku/OCM/blob/main/test_cifar10.py
    '''
    if multi_gpu:
        world_size = dist.get_world_size()
        gathered_pieces = []
        for piece in outputs.chunk(chunk):
            receive = [torch.empty_like(piece) for _ in range(world_size)]
            gathered_pieces.append(torch.cat(distops.all_gather(receive, piece)))
        outputs = torch.cat(gathered_pieces)

    return torch.mm(outputs, outputs.t())
|
|
|
|
def Supervised_NT_xent_n(sim_matrix, labels, embedding=None,temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False):
    '''
    Supervised NT-Xent loss with the self-similarity (diagonal) masked out.

    - sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples)
    - labels: (B,) tensor; it is tiled twice to label all B' rows
    - embedding, multi_gpu: accepted but not used

    The returned value is the sum of a label-weighted term (samples sharing a
    label are positives) and an instance term that pairs row i with row i + B.

    Code Reference:
    https://github.com/gydpku/OCM/blob/main/test_cifar10.py
    '''
    device = sim_matrix.device
    half = sim_matrix.size(0) // chunk

    # Subtract the row maximum before exponentiating, for numerical stability.
    row_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
    shifted = sim_matrix - row_max.detach()

    off_diagonal = 1 - torch.eye(half * chunk).to(device)
    weights = torch.exp(shifted / temperature) * off_diagonal
    row_total = torch.sum(weights, dim=1, keepdim=True)
    neg_log_prob = -torch.log(weights/(row_total + eps) + eps)

    # Row-normalised indicator of "same label".
    tiled_labels = labels.repeat(2).contiguous().view(-1, 1)
    same_label = torch.eq(tiled_labels, tiled_labels.t()).float().to(device)
    same_label = same_label / (same_label.sum(dim=1, keepdim=True) + eps)

    supervised_term = 2 * torch.sum(same_label * neg_log_prob) / (2 * half)
    paired = neg_log_prob[:half, half:].diag() + neg_log_prob[half:, :half].diag()
    return (torch.sum(paired) / (2 * half)) + supervised_term
|
|
|
|
def Supervised_NT_xent_uni(sim_matrix, labels, temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False):
    '''
    Supervised NT-Xent loss that keeps the diagonal in the softmax.

    - sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples)
    - labels: (B,) tensor; it is tiled twice to label all B' rows
    - multi_gpu: accepted but not used

    Code Reference:
    https://github.com/gydpku/OCM/blob/main/test_cifar10.py
    '''
    device = sim_matrix.device
    half = sim_matrix.size(0) // chunk

    # Subtract the row maximum before exponentiating, for numerical stability.
    row_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
    weights = torch.exp((sim_matrix - row_max.detach()) / temperature)
    row_total = torch.sum(weights, dim=1, keepdim=True)
    neg_log_prob = - torch.log(weights / (row_total + eps) + eps)

    # Row-normalised indicator of "same label".
    tiled_labels = labels.repeat(2).contiguous().view(-1, 1)
    same_label = torch.eq(tiled_labels, tiled_labels.t()).float().to(device)
    same_label = same_label / (same_label.sum(dim=1, keepdim=True) + eps)

    return torch.sum(same_label * neg_log_prob) / (2 * half)
|
|
|
|
|
|
|
|
|
|
def Supervised_NT_xent_pre(sim_matrix, labels, temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False):
    '''
    Supervised NT-Xent loss whose labels are used as given (not tiled).

    - sim_matrix: square tensor; B is taken as ``sim_matrix.size(0) // chunk``
    - labels: one label per row of ``sim_matrix``
    - multi_gpu: accepted but not used

    Code Reference:
    https://github.com/gydpku/OCM/blob/main/test_cifar10.py
    '''
    device = sim_matrix.device
    half = sim_matrix.size(0) // chunk

    # Subtract the row maximum before exponentiating, for numerical stability.
    row_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
    weights = torch.exp((sim_matrix - row_max.detach()) / temperature)
    row_total = torch.sum(weights, dim=1, keepdim=True)
    neg_log_prob = -torch.log(weights/(row_total+eps)+eps)

    # Row-normalised indicator of "same label".
    column_labels = labels.contiguous().view(-1, 1)
    same_label = torch.eq(column_labels, column_labels.t()).float().to(device)
    same_label = same_label / (same_label.sum(dim=1, keepdim=True) + eps)

    return torch.sum(same_label * neg_log_prob) / (2 * half)
|
|
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
|
|
class OCM_Model(nn.Module):
    """Backbone with two linear heads: a classifier and a 128-d self-supervised head."""

    def __init__(self, backbone, feat_dim, num_class, device):
        '''
        Args:
            backbone: callable whose output is indexed with ``['features']``.
            feat_dim: size of the backbone features fed to both heads.
            num_class: number of classifier outputs.
            device: stored on the module as ``self.device``.
        '''
        super(OCM_Model, self).__init__()
        self.device = device
        self.backbone = backbone
        self.head = nn.Linear(feat_dim, 128)
        self.classifier = nn.Linear(feat_dim, num_class)

    def get_features(self, x):
        """Return the ``'features'`` entry of the backbone output for ``x``."""
        return self.backbone(x)['features']

    def forward_head(self, x):
        """Return ``(features, head(features))`` for ``x``."""
        features = self.get_features(x)
        return features, self.head(features)

    def forward_classifier(self, x):
        """Return the classifier logits for ``x``."""
        return self.classifier(self.get_features(x))
|
|
class OCM(nn.Module):
    """Online continual learner based on mutual information maximisation (OCM).

    Wraps an ``OCM_Model`` together with a frozen-in-time copy of it from the
    previous task, a replay buffer supplied through ``before_task`` and the
    augmentation pipeline used for the contrastive views.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """
        Args:
            backbone: feature extractor handed to ``OCM_Model``.
            feat_dim: feature size handed to ``OCM_Model``.
            num_class: number of classifier outputs.
            **kwargs: must provide ``device``, ``init_cls_num``,
                ``inc_cls_num``, ``task_num`` and ``image_size``.
        """
        super(OCM, self).__init__()

        self.device = kwargs['device']

        # Index of the task currently being learned; advanced in after_task.
        self.cur_task_id = 0

        self.model = OCM_Model(backbone, feat_dim, num_class, self.device)

        # Copy of self.model taken at the end of the previous task.
        self.previous_model = None

        # Distinct labels seen so far (stored as detached 0-d tensors).
        self.class_holder = []

        # Replay samples requested per seen class (capped at 64 per batch).
        self.buffer_per_class = 7

        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.task_num = kwargs['task_num']
        self.image_size = kwargs['image_size']

        self.simclr_aug = torch.nn.Sequential(
            HorizontalFlipLayer().to(self.device),
            RandomColorGrayLayer(p=0.25).to(self.device),
            RandomResizedCropLayer(scale=(0.3, 1.0), size=[self.image_size, self.image_size, 3]).to(self.device)
        )

    def observe(self, data):
        """Process one training batch.

        Args:
            data: mapping with ``'image'`` and ``'label'`` tensors.

        Returns:
            ``(pred, acc, loss)`` as produced by ``observe_first_task`` or
            ``observe_incremental_tasks``; the batch is added to the buffer
            afterwards.
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        # Remember every label that has not been seen before.
        for label in deepcopy(y):
            if label not in self.class_holder:
                self.class_holder.append(label.detach())

        x = x.requires_grad_()

        if self.cur_task_id == 0:
            pred, acc, loss = self.observe_first_task(x, y)
        else:
            pred, acc, loss = self.observe_incremental_tasks(x, y)

        self.buffer.add_reservoir(x=x.detach(), y=y.detach(), task=self.cur_task_id)

        return pred, acc, loss

    def observe_first_task(self, x, y):
        """Compute the loss for a batch of the first task (no replay yet).

        Returns:
            ``(y_pred, acc, loss)``: classifier logits on the augmented batch,
            batch accuracy of their argmax, and cross-entropy plus the
            contrastive loss.

        Code Reference:
        https://github.com/gydpku/OCM/blob/main/test_cifar10.py
        """
        images1, rot_sim_labels = Rotation(x, y)
        images_pair = torch.cat([images1, self.simclr_aug(images1)], dim=0)
        # Follow the configured device instead of hard-coding CUDA.
        rot_sim_labels = rot_sim_labels.to(self.device)

        feature_map, outputs_aux = self.model.forward_head(images_pair)
        simclr = normalize(outputs_aux)
        feature_map_out = normalize(feature_map[:images_pair.shape[0]])

        # Compare the head output with a random contiguous slice of the
        # backbone features that has the same width as the head output.
        size = simclr.shape[1]
        num1 = feature_map_out.shape[1] - size
        id1 = torch.randperm(num1)[0]
        sim_matrix = torch.matmul(simclr, feature_map_out[:, id1:id1 + size].t())
        sim_matrix += get_similarity_matrix(simclr)

        loss_sim1 = Supervised_NT_xent_n(sim_matrix, labels=rot_sim_labels, temperature=0.07)

        y_pred = self.model.forward_classifier(self.simclr_aug(x))
        loss = F.cross_entropy(y_pred, y) + loss_sim1

        # The predicted class is the one with the highest logit.
        pred = torch.argmax(y_pred, dim=1)
        acc = torch.sum(pred == y).item() / x.size(0)

        return y_pred, acc, loss

    def observe_incremental_tasks(self, x, y):
        """Compute the loss for a batch of a later task, using replay samples.

        The loss sums the cross-entropy on replayed samples, three contrastive
        terms (current batch, replayed batch, replayed batch against the
        previous model) and an MSE term between the previous and current
        classifier outputs on the previously known classes.

        Returns:
            ``(logits, acc, loss)`` where ``logits`` are the classifier
            outputs on ``x`` restricted to the classes known so far.

        Code Reference:
        https://github.com/gydpku/OCM/blob/main/test_cifar10.py
        """
        buffer_batch_size = min(64, self.buffer_per_class * len(self.class_holder))
        mem_x, mem_y, _ = self.buffer.sample(buffer_batch_size, exclude_task=None)
        mem_x = mem_x.requires_grad_()

        images1, rot_sim_labels = Rotation(x, y)
        images1_r, rot_sim_labels_r = Rotation(mem_x, mem_y)
        images_pair = torch.cat([images1, self.simclr_aug(images1)], dim=0)
        images_pair_r = torch.cat([images1_r, self.simclr_aug(images1_r)], dim=0)

        # One forward pass for current and replayed views together.
        n_current = images_pair.shape[0]
        t = torch.cat((images_pair, images_pair_r), dim=0)
        feature_map, u = self.model.forward_head(t)
        pre_u_feature, pre_u = self.previous_model.forward_head(images1_r)

        feature_map_out = normalize(feature_map[:n_current])
        feature_map_out_r = normalize(feature_map[n_current:])
        pre_u = normalize(pre_u)
        simclr = normalize(u[:n_current])
        simclr_r = normalize(u[n_current:])

        # Random contiguous feature slices with the width of the head output.
        size = simclr.shape[1]
        num1 = feature_map_out.shape[1] - size
        id1 = torch.randperm(num1)[0]
        id2 = torch.randperm(num1)[0]

        sim_matrix = torch.matmul(simclr, feature_map_out[:, id1:id1 + size].t())
        sim_matrix_r = torch.matmul(simclr_r, feature_map_out_r[:, id2:id2 + size].t())
        sim_matrix += get_similarity_matrix(simclr)
        sim_matrix_r += get_similarity_matrix(simclr_r)
        sim_matrix_r_pre = torch.matmul(simclr_r[:images1_r.shape[0]], pre_u.t())

        loss_sim_r = Supervised_NT_xent_uni(sim_matrix_r, labels=rot_sim_labels_r, temperature=0.07)
        loss_sim_pre = Supervised_NT_xent_pre(sim_matrix_r_pre, labels=rot_sim_labels_r, temperature=0.07)
        loss_sim = Supervised_NT_xent_n(sim_matrix, labels=rot_sim_labels, temperature=0.07)
        lo1 = loss_sim_r + loss_sim + loss_sim_pre

        y_label = self.model.forward_classifier(self.simclr_aug(mem_x))
        y_label_pre = self.previous_model.forward_classifier(self.simclr_aug(mem_x))
        distill = F.mse_loss(y_label_pre[:, :self.prev_cls_num], y_label[:, :self.prev_cls_num])
        loss = F.cross_entropy(y_label, mem_y) + lo1 + distill

        with torch.no_grad():
            logits = self.model.forward_classifier(x)[:, :self.accu_cls_num]
            pred = torch.argmax(logits, dim=1)
            acc = torch.sum(pred == y).item() / x.size(0)
        return logits, acc, loss

    def inference(self, data):
        """Classify a batch.

        Args:
            data: mapping with ``'image'`` and ``'label'`` tensors.

        Returns:
            ``(pred, acc)``: predicted class indices and batch accuracy.
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)
        logits = self.model.forward_classifier(x)
        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Attach the buffer on the first task and update the known-class count."""
        if self.cur_task_id == 0:
            self.buffer = buffer
            self.accu_cls_num = self.init_cls_num
        else:
            self.accu_cls_num += self.inc_cls_num

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Record the known-class count, advance the task id and snapshot the model."""
        self.prev_cls_num = self.accu_cls_num
        self.cur_task_id += 1
        self.previous_model = deepcopy(self.model)

    def get_parameters(self, config):
        """Return the parameters of the current model (``config`` is not used)."""
        return self.model.parameters()