| |
| """ |
| TODO: citation |
| |
| Adapted from TODO: source |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
| from torch import optim |
| from torch.nn import functional as F |
| from torch.nn.parameter import Parameter |
| from tqdm import tqdm |
|
|
| from .backbone.transformer import ResidualAttentionBlock |
| from .backbone.clip import tokenize, CLIP |
| from .backbone.vit import ViTZoo |
|
|
# Short alias for the ViT backbone class; ``CLIP`` is already bound by the backbone import.
VIT = ViTZoo
|
|
class DMNSP(nn.Module):
    """Continual learner that projects adapter gradients with per-layer activation bases.

    Backbone parameters whose name does not contain ``'adapt'`` are frozen. After each
    task, the activation covariance of every tracked visual block (first training batch
    only) is decomposed with SVD, and all left-singular vectors except the leading one
    are kept in ``visual_U``. While training any later task, the gradients of the adapter
    ``down``/``up`` weights are projected onto those bases and rescaled by the
    cross-layer weights ``lamda`` computed in :meth:`before_task`.

    Two backbones are supported: ``VIT`` (logits from one linear head per task in
    ``classifier_pool``) and ``CLIP`` (logits from the backbone given tokenised class
    prompts).
    """

    def __init__(self, backbone, device, **kwargs):
        """Wrap ``backbone`` and initialise the per-task bookkeeping.

        Args:
            backbone: a ``VIT`` or ``CLIP`` instance.
            device: device that batches and text tokens are moved to.
            **kwargs: required keys ``init_cls_num``, ``inc_cls_num``,
                ``label_smoothing``, ``lamda_scale`` and ``prompt_template``; the ``VIT``
                backbone also needs ``embd_dim`` and ``task_num``. Optional
                ``num_layers`` (default 12) is how many visual blocks are tracked.

        Raises:
            TypeError: if ``backbone`` is neither a ``VIT`` nor a ``CLIP`` instance.
        """
        super().__init__()

        self.device = device
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.label_smoothing = kwargs['label_smoothing']

        self._cur_task_id = -1
        self._known_classes = 0

        # Number of visual transformer blocks whose activations are tracked.
        self.num_layers = kwargs.get('num_layers', 12)
        # visual_U[i]: basis (one direction per column) kept for block i; see after_task.
        self.visual_U = []
        # lamda[l][i]: scale used when projecting the gradient of an adapter whose
        # parameter name carries index l onto visual_U[i]; filled in by before_task.
        self.lamda = [[0] * self.num_layers for _ in range(self.num_layers)]
        self.lamda_scale = kwargs['lamda_scale']

        self.accm_class_names = []
        self.curr_class_names = []
        self.accm_text_tokens = None
        self.curr_text_tokens = None

        self.prompt_template = kwargs['prompt_template']

        self._network = backbone

        # Only the adapter parameters of the backbone stay trainable.
        for name, param in self._network.named_parameters():
            if 'adapt' not in name:
                param.requires_grad = False

        if isinstance(self._network, VIT):
            self.visual_transformer_blocks = [
                module for module in self._network.modules()
                if isinstance(module, ResidualAttentionBlock)
            ]
            # One head per task: init_cls_num outputs for task 0, inc_cls_num afterwards.
            embd_dim = kwargs['embd_dim']
            self.classifier_pool = nn.ModuleList(
                [nn.Linear(embd_dim, self.init_cls_num, bias=True)]
                + [nn.Linear(embd_dim, self.inc_cls_num, bias=True)
                   for _ in range(kwargs['task_num'] - 1)]
            )
        elif isinstance(self._network, CLIP):
            # Only blocks whose module path contains 'visual' are tracked.
            self.visual_transformer_blocks = [
                module for name, module in self._network.named_modules()
                if isinstance(module, ResidualAttentionBlock) and 'visual' in name
            ]
        else:
            raise TypeError(
                f"DMNSP supports VIT or CLIP backbones, got {type(backbone).__name__}"
            )

    def observe(self, data):
        """Compute the loss on one training batch and back-propagate it.

        Labels are shifted by ``_known_classes`` so they index the current task's
        classes only. ``loss.backward()`` runs here so that, for tasks after the first,
        the adapter gradients can be projected before the caller's optimizer step.
        NOTE(review): this relies on the caller not zeroing the gradients or calling
        ``backward`` again before stepping — confirm against the trainer.

        Args:
            data: dict with ``'image'`` and ``'label'`` tensors.

        Returns:
            ``(preds, acc, loss)``: task-local predictions, batch accuracy as a float,
            and the (already back-propagated) loss tensor.
        """
        x = data['image'].to(self.device)
        y = data['label'].to(self.device) - self._known_classes

        if isinstance(self._network, CLIP):
            _, _, logits_per_img, _ = self._network(x, self.curr_text_tokens)
        elif isinstance(self._network, VIT):
            features = self._network(x)
            # Only the current task's head is trained on this task's data.
            logits_per_img = self.classifier_pool[self._cur_task_id](features)

        loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing)

        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)
        acc = preds.eq(y).sum().item() / y.size(0)

        loss.backward()

        if self._cur_task_id > 0:
            # With CLIP, only adapters whose name contains 'visual' are projected.
            self._project_adapter_grads(visual_only=isinstance(self._network, CLIP))

        return preds, acc, loss

    def _project_adapter_grads(self, visual_only):
        """Project adapter ``down``/``up`` weight gradients onto every stored basis in turn.

        For each matching parameter with gradient ``g``, and for ``i`` in
        ``0 .. num_layers - 1`` in order, ``g`` is replaced by
        ``g @ V_i @ V_i.T * lamda[l][i]`` for ``down`` weights or
        ``V_i @ V_i.T @ g * lamda[l][i]`` for ``up`` weights, where ``V_i = visual_U[i]``
        and ``l`` is the 4th dot-separated component of the parameter name.
        Parameters without a gradient are skipped.

        Args:
            visual_only: if true, additionally require ``'visual'`` in the parameter name.
        """
        for name, param in self._network.named_parameters():
            if 'adapt' not in name or 'weight' not in name:
                continue
            if visual_only and 'visual' not in name:
                continue
            if 'down' in name:
                is_down = True
            elif 'up' in name:
                is_down = False
            else:
                continue
            if param.grad is None:
                continue

            layer = int(name.split(".")[3])
            grad = param.grad.data
            for i in range(self.num_layers):
                basis = self.visual_U[i].to(self.device)
                if is_down:
                    grad = torch.mm(torch.mm(grad, basis), basis.T) * self.lamda[layer][i]
                else:
                    grad = torch.mm(basis, torch.mm(basis.T, grad)) * self.lamda[layer][i]
            param.grad.data = grad

    def inference(self, data, task_id=-1):
        """Predict labels for ``data``.

        Args:
            data: dict with ``'image'`` and ``'label'`` tensors.
            task_id: ``-1`` to classify over all classes seen so far; otherwise restrict
                prediction to that task's classes (CLIP only, and only when
                ``init_cls_num == inc_cls_num``).

        Returns:
            ``(preds, acc)``: predictions in global class ids and batch accuracy.

        Raises:
            NotImplementedError: for ``task_id > -1`` with the ``VIT`` backbone.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        if isinstance(self._network, CLIP):
            if task_id > -1:
                assert self.init_cls_num == self.inc_cls_num
                start = task_id * self.inc_cls_num
                text_tokens = self.accm_text_tokens[start:start + self.inc_cls_num]
            else:
                text_tokens = self.accm_text_tokens
            _, _, logits_per_img, _ = self._network(x, text_tokens)
        elif isinstance(self._network, VIT):
            if task_id > -1:
                raise NotImplementedError('Not Implemented')
            features = self._network(x)
            # Concatenate the heads of every task seen so far, in task order.
            logits_per_img = torch.cat(
                [head(features) for head in self.classifier_pool[:self._cur_task_id + 1]],
                dim=1,
            )

        preds = logits_per_img.softmax(dim=-1).argmax(dim=1)

        if task_id > -1:
            # Map task-local predictions back to global class ids.
            preds += task_id * self.inc_cls_num

        acc = preds.eq(y).sum().item() / y.size(0)

        return preds, acc

    @torch.no_grad()
    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare state for task ``task_idx`` (``buffer``/``test_loaders`` are unused).

        Updates the label offset and the current/accumulated text prompts. For
        ``task_idx > 0`` it also recomputes ``lamda`` from the first batch of
        ``train_loader``. For every tracked block ``j``, the leading left-singular vector
        of its activation covariance is compared, by cosine similarity, with every column
        of each stored basis ``visual_U[k]``. The result is
        ``lamda[j][k] = exp(-mean of the top 10% similarities) * lamda_scale``.
        """
        self._cur_task_id = task_idx
        if task_idx == 1:
            self._known_classes = self.init_cls_num
        elif task_idx > 1:
            self._known_classes += self.inc_cls_num

        self.curr_class_names = train_loader.dataset.get_class_names()
        self.accm_class_names += self.curr_class_names

        self.curr_text_tokens = tokenize(
            [self.prompt_template.format(c) for c in self.curr_class_names]
        ).to(self.device)
        self.accm_text_tokens = tokenize(
            [self.prompt_template.format(c) for c in self.accm_class_names]
        ).to(self.device)

        if task_idx <= 0:
            return

        covariances = self._first_batch_covariances(train_loader)
        if covariances is None:
            return
        for j, cov in enumerate(covariances):
            U, _, _ = torch.linalg.svd(cov, full_matrices=False)
            leading = U[:, 0:1]
            for k in range(self.num_layers):
                self.lamda[j][k] = self._subspace_weight(
                    leading, self.visual_U[k], self.lamda_scale
                )

    @torch.no_grad()
    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow ``visual_U`` from the first batch of ``train_loader``.

        Each tracked block's activation covariance is decomposed with SVD. All
        left-singular vectors except the leading one (``U[:, 1:]``) are stored as that
        block's basis after task 0, or appended as extra columns to the existing basis
        after later tasks. ``buffer`` and ``test_loaders`` are unused.
        """
        covariances = self._first_batch_covariances(train_loader)
        if covariances is None:
            return
        for i, cov in enumerate(covariances):
            U, _, _ = torch.linalg.svd(cov, full_matrices=False)
            kept = U[:, 1:]  # drop the leading singular direction
            if task_idx == 0:
                self.visual_U.append(kept)
            else:
                self.visual_U[i] = torch.cat((self.visual_U[i], kept), dim=1)

    def _first_batch_covariances(self, train_loader):
        """Forward the first batch of ``train_loader`` and return one covariance per block.

        Returns ``None`` when the loader yields no batch. NOTE(review): the backbone is
        called with text tokens and ``compute_lora_feat=True`` for both backbones, whereas
        ``observe`` calls a VIT backbone with images only — confirm ViTZoo accepts these
        arguments.
        """
        batch = next(iter(train_loader), None)
        if batch is None:
            return None
        x = batch['image'].to(self.device)
        self._network(x, self.curr_text_tokens, compute_lora_feat=True)
        return [
            self._activation_covariance(self.visual_transformer_blocks[i].lora_feature)
            for i in range(self.num_layers)
        ]

    @staticmethod
    def _activation_covariance(feature):
        """Return ``sum over b of feature[:, b, :].T @ feature[:, b, :]``.

        For a 3-D ``feature`` of shape ``(A, B, C)`` the result is ``(C, C)``.
        Presumably ``A`` is sequence length, ``B`` batch and ``C`` width — TODO confirm
        against ``ResidualAttentionBlock.lora_feature``.
        """
        return torch.bmm(feature.permute(1, 2, 0), feature.permute(1, 0, 2)).sum(dim=0)

    @staticmethod
    def _subspace_weight(direction, basis, scale, top_ratio=0.1):
        """Return ``exp(-mean of the top-k cosine similarities) * scale``.

        Similarities are taken between ``direction`` (a ``(d, 1)`` column) and every
        column of ``basis`` (``(d, m)``). ``k = int(m * top_ratio)``, clamped to at
        least 1, so that a basis with fewer than ``1 / top_ratio`` columns does not
        average an empty tensor into NaN.
        """
        unit_dir = (direction / torch.norm(direction)).reshape(-1)
        unit_cols = basis / torch.norm(basis, dim=0, keepdim=True)
        similarities = unit_cols.T @ unit_dir
        k = max(1, int(similarities.numel() * top_ratio))
        return torch.exp(-torch.topk(similarities, k)[0].mean()) * scale

    def get_parameters(self, config):
        """Return the parameters to hand to the optimizer (``config`` is unused).

        This is every parameter registered on this module: the backbone and, for the
        ``VIT`` backbone, the per-task heads in ``classifier_pool``, which are attributes
        of this module rather than of the backbone. Frozen backbone parameters
        (``requires_grad=False``) are included as well.
        """
        return self.parameters()