'''
@misc{mcdonnell2024ranpacrandomprojectionspretrained,
      title={RanPAC: Random Projections and Pre-trained Models for Continual Learning},
      author={Mark D. McDonnell and Dong Gong and Amin Parvaneh and Ehsan Abbasnejad and Anton van den Hengel},
      year={2024},
      eprint={2307.02251},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2307.02251},
}

Code Reference:
https://github.com/RanPAC/RanPAC
'''
|
|
| import copy |
| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .backbone.transformer import MultiHeadAttention_LoRA, VisionTransformer |
| from .backbone.clip import CLIP, tokenize |
| from .backbone.vit import ViTZoo, ViT_in21k_adapter |
|
|
# Canonical backbone aliases used by the isinstance() dispatch throughout this
# module (Network.__init__/forward, RanPAC.__init__).
VIT = ViT_in21k_adapter  # the ViT variant this method is wired to
CLIP = CLIP  # no-op rebinding of the imported class; kept for symmetry with VIT
|
|
class CosineLinear(nn.Module):
    '''Classifier head with two operating modes.

    Default mode: cosine-similarity classifier — both the input rows and the
    weight rows are L2-normalized before the dot product.

    RP mode (``use_RP = True``): the input is first projected with the fixed
    random matrix ``W_rand`` (shape ``(in_features, M)``, assigned externally)
    and passed through ReLU, then multiplied by ``weight`` without
    normalization. In this mode ``weight`` is overwritten externally with the
    closed-form ridge-regression solution (see ``RanPAC.update_rp_classifier``).

    The learnable scalar ``sigma`` scales the output in both modes.
    '''

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
        self.sigma = nn.Parameter(torch.Tensor(1))
        self.reset_parameters()

        # Random-projection mode is switched on externally by RanPAC.
        self.use_RP = False
        self.W_rand = None

    def reset_parameters(self):
        '''Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)]; sigma starts at 1.'''
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        self.sigma.data.fill_(1)

    def forward(self, input):
        if not self.use_RP:
            # Cosine similarity between each input row and each weight row.
            out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
        else:
            # Fail loudly rather than via `assert` (stripped under `python -O`,
            # which would have silently fallen through to an un-projected
            # `inn = input` dead branch in the original code).
            if self.W_rand is None:
                raise RuntimeError('use_RP is set but W_rand has not been assigned')
            inn = F.relu(input @ self.W_rand)
            out = F.linear(inn, self.weight)

        return self.sigma * out
|
|
class Network(nn.Module):
    '''Pre-trained backbone (ViT or CLIP) topped with a per-task CosineLinear head.

    For CLIP backbones the classifier input is the image embedding concatenated
    with the text embedding of the class the image matches best, so the feature
    dimension is ``visual.output_dim + transformer.width``.
    '''

    def __init__(self, backbone, device, **kwargs):
        super().__init__()

        self._cur_task_id = -1
        self.backbone = backbone
        self.device = device
        self.classifier = None

        if isinstance(backbone, VIT):
            self.feature_dim = backbone.feat_dim
        elif isinstance(backbone, CLIP):
            # Image embedding + best-matching text embedding, concatenated.
            self.feature_dim = backbone.visual.output_dim + backbone.transformer.width
            self.accm_class_names = []   # all class names seen so far
            self.curr_class_names = []   # class names of the current task only
            self.accm_text_tokens = None
            self.curr_text_tokens = None

        self.prompt_template = kwargs['prompt_template']

    def _tokenize_names(self, class_names):
        # One prompt per class name, tokenized and moved to the device.
        prompts = [self.prompt_template.format(name) for name in class_names]
        return tokenize(prompts).to(self.device)

    def _clip_features(self, x, text_tokens):
        # Image features concatenated with the text features of the class each
        # image scores highest against.
        img_feat, txt_feat, logits_per_image, _ = self.backbone(x, text_tokens)
        best = logits_per_image.softmax(dim=-1).argmax(dim=1)
        return torch.cat([img_feat, txt_feat[best]], dim=1)

    def update_classifer(self, num_classes, train_loader):
        '''Advance to the next task and replace the head with a wider CosineLinear.'''
        if isinstance(self.backbone, VIT):
            pass
        elif isinstance(self.backbone, CLIP):
            self.curr_class_names = train_loader.dataset.get_class_names()
            self.accm_class_names += self.curr_class_names
            self.curr_text_tokens = self._tokenize_names(self.curr_class_names)
            self.accm_text_tokens = self._tokenize_names(self.accm_class_names)
        else:
            assert 0

        self._cur_task_id += 1
        del self.classifier
        self.classifier = CosineLinear(self.feature_dim, num_classes).to(self.device)

    def get_feature(self, x):
        '''Backbone features for x; CLIP uses the current task's text tokens.'''
        if isinstance(self.backbone, VIT):
            return self.backbone(x)
        if isinstance(self.backbone, CLIP):
            return self._clip_features(x, self.curr_text_tokens)
        assert 0

    def forward(self, x, inference=False):
        '''Classify x; at inference time CLIP matches against all classes seen so far.'''
        if isinstance(self.backbone, VIT):
            features = self.backbone(x)
        elif isinstance(self.backbone, CLIP):
            tokens = self.accm_text_tokens if inference else self.curr_text_tokens
            features = self._clip_features(x, tokens)
        else:
            assert 0

        return self.classifier(features)
|
|
class RanPAC(nn.Module):
    '''RanPAC continual learner (McDonnell et al., 2024, arXiv:2307.02251).

    Gradient training happens at most on the first task (and only when
    ``first_session_training`` is set); every task afterwards updates the
    classifier in closed form by ridge regression on ReLU-activated random
    projections of frozen backbone features (``update_rp_classifier``).
    '''

    def __init__(self, backbone, device, **kwargs):
        super().__init__()

        self._network = Network(backbone, device, **kwargs)

        self.device = device
        self.first_session_training = kwargs["first_session_training"]
        self.init_cls_num = kwargs["init_cls_num"]
        self.inc_cls_num = kwargs["inc_cls_num"]
        self.total_cls_num = kwargs['total_cls_num']
        self.task_num = kwargs["task_num"]
        # Dimensionality of the random-projection space.
        self.M = kwargs['M']

        self._known_classes = 0
        self._classes_seen_so_far = 0
        self._skip_train = False

        self._network.to(self.device)

        # For CLIP backbones, freeze everything except adapter parameters
        # (names containing 'adapt').
        if isinstance(backbone, CLIP):
            for name, param in self._network.named_parameters():
                if 'adapt' not in name:
                    param.requires_grad = False

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        '''Grow the class count, rebuild the classifier head, decide whether to train.'''
        if task_idx == 0:
            self._classes_seen_so_far = self.init_cls_num
        elif task_idx > 0:
            self._classes_seen_so_far += self.inc_cls_num

        self._network.update_classifer(self._classes_seen_so_far, train_loader)

        # Only the first session is (optionally) trained with SGD; later tasks
        # rely purely on the closed-form RP update in after_task().
        if task_idx == 0 and self.first_session_training:
            self._skip_train = False
        else:
            self._skip_train = True
            print(f"Not training on task {task_idx}")

    def observe(self, data):
        '''One training step. Returns (predictions, batch accuracy, loss).'''
        if self._skip_train:
            # Dummy zero loss with requires_grad so the outer loop's backward()
            # still works when this task is not trained.
            return None, 0., torch.tensor(0., device = self.device, requires_grad = True)

        # Labels are shifted by the number of previously known classes; this is
        # only ever a no-op shift in practice because training happens solely on
        # task 0, where _known_classes == 0.
        inputs, targets = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes

        logits = self._network(inputs)
        loss = F.cross_entropy(logits, targets)

        _, preds = torch.max(logits, dim=1)
        correct = preds.eq(targets.expand_as(preds)).sum().item()
        total = len(targets)

        acc = round(correct / total, 4)

        return preds, acc, loss

    def inference(self, data):
        '''Evaluate one batch; forward(..., True) uses all classes seen so far.'''
        inputs, targets = data['image'].to(self.device), data['label']
        logits = self._network(inputs, True)
        _, preds = torch.max(logits, dim=1)

        correct = preds.cpu().eq(targets.expand_as(preds)).sum().item()
        total = len(targets)

        acc = round(correct / total, 4)

        return logits, acc

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        '''Allocate/grow the RP statistics, then run the closed-form update.'''
        self._known_classes = self._classes_seen_so_far

        if task_idx == 0:
            # W_rand: fixed random projection (feature_dim x M), drawn once.
            # Q: projected-feature/label cross-products (M x num_classes).
            # G: Gram matrix of projected features (M x M).
            # Both Q and G accumulate across all tasks (kept on CPU).
            self.W_rand = torch.randn(self._network.classifier.in_features, self.M)
            self.Q = torch.zeros(self.M, self.init_cls_num)
            self.G = torch.zeros(self.M, self.M)
        else:
            # Widen Q with zero columns for the new task's classes.
            self.Q = torch.cat((self.Q, torch.zeros(self.M, self.inc_cls_num)), dim=1)

        # test_loaders[0].dataset.trfms is presumably the eval-time transform
        # pipeline — features are extracted without train-time augmentation.
        self.update_rp_classifier(train_loader, test_loaders[0].dataset.trfms)

    @torch.no_grad()
    def update_rp_classifier(self, train_loader, test_trfms):
        '''Closed-form ridge-regression update of the classifier weights.

        Solves Wo = (G + ridge*I)^{-1} Q (transposed) over the accumulated
        statistics, with the ridge strength chosen on the current task via an
        80/20 split. NOTE(review): this permanently replaces the train
        loader's transforms with the eval transforms — apparently intentional,
        since the loader is not needed for training afterwards.
        '''
        self._network.eval()
        train_loader.dataset.trfms = test_trfms

        # Switch the head into RP mode and hand it the shared projection.
        self._network.classifier.use_RP = True
        self._network.classifier.W_rand = self.W_rand.to(self.device)

        # Extract frozen-backbone features for the whole task (moved to CPU).
        feature_list, label_list = [], []
        for batch in train_loader:
            x, y = batch['image'].to(self.device), batch['label']
            feature_list.append(self._network.get_feature(x).cpu())
            label_list.append(y)
        feature_list, label_list = torch.cat(feature_list, dim = 0), torch.cat(label_list, dim = 0)

        label_list = F.one_hot(label_list, self._classes_seen_so_far).to(torch.float32)

        # ReLU of the random projection, matching CosineLinear's RP forward.
        proj_feature_list = F.relu(feature_list @ self.W_rand)

        # Accumulate this task's statistics into the running totals.
        self.Q += proj_feature_list.T @ label_list
        self.G += proj_feature_list.T @ proj_feature_list

        # Ridge search over 10^-8 .. 10^8 using only the current task's data:
        # fit on the first 80% of samples, score MSE on the remaining 20%
        # (despite its name, Y_train_pred holds the held-out predictions).
        ridges = 10.0**np.arange(-8,9)
        num_val_samples = int(proj_feature_list.shape[0] * 0.8)
        losses = []
        Q_val = proj_feature_list[:num_val_samples, :].T @ label_list[:num_val_samples, :]
        G_val = proj_feature_list[:num_val_samples, :].T @ proj_feature_list[:num_val_samples, :]
        for ridge in ridges:
            Wo = torch.linalg.solve(G_val + ridge * torch.eye(self.M), Q_val).T
            Y_train_pred = proj_feature_list[num_val_samples:, :] @ Wo.T
            losses.append(F.mse_loss(Y_train_pred, label_list[num_val_samples:, :]))
        ridge = ridges[np.argmin(np.array(losses))]
        print(f"Optimal lambda: {ridge}")

        # Final solve over the accumulated Q/G; overwrite the head's weights.
        # The slice keeps the first num_classes rows; assigning via .data also
        # reshapes the parameter from (num_classes, feature_dim) to
        # (num_classes, M), as required by RP mode.
        Wo = torch.linalg.solve(self.G + ridge * torch.eye(self.M), self.Q).T
        self._network.classifier.weight.data = Wo[:self._network.classifier.weight.shape[0], :].to(self.device)

    def get_parameters(self, config):
        '''Parameters handed to the optimizer (config is unused here).'''
        return self._network.parameters()