| |
| """ |
| @inproceedings{DBLP:conf/cvpr/0002ZL0SRSPDP22, |
| author = {Zifeng Wang and |
| Zizhao Zhang and |
| Chen{-}Yu Lee and |
| Han Zhang and |
| Ruoxi Sun and |
| Xiaoqi Ren and |
| Guolong Su and |
| Vincent Perot and |
| Jennifer G. Dy and |
| Tomas Pfister}, |
| title = {Learning to Prompt for Continual Learning}, |
| booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, |
| {CVPR} 2022, New Orleans, LA, USA, June 18-24, 2022}, |
| pages = {139--149}, |
| publisher = {{IEEE}}, |
| year = {2022} |
| } |
| |
| https://arxiv.org/abs/2112.08654 |
| |
| Adapted from https://github.com/GT-RIPL/CODA-Prompt |
| """ |
|
|
| import math |
| import copy |
| import torch |
| import numpy as np |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from core.model.backbone.resnet import * |
|
|
class Model(nn.Module):
    """Backbone followed by a single linear classification head."""

    def __init__(self, backbone, embed_dim, total_cls_num):
        """Store the backbone and create the classification head.

        Args:
            backbone: Callable module invoked as ``backbone(x, train=...)``;
                it must return a pair whose first item is fed to the head.
            embed_dim: Input width of the linear head.
            total_cls_num: Output width of the linear head (one logit per class).
        """
        super().__init__()
        self.backbone = backbone
        head = nn.Linear(embed_dim, total_cls_num, bias=True)
        self.classifier = head

    def forward(self, x, train=True):
        """Return ``(logits, reduce_sim)`` for the batch ``x``.

        Args:
            x: Input forwarded unchanged to the backbone.
            train: Flag forwarded to the backbone as the ``train`` keyword.

        Returns:
            tuple: The head's output on the backbone features, and the second
            value produced by the backbone, passed through untouched
            (presumably L2P's prompt-key similarity term -- confirm against
            the backbone implementation).
        """
        backbone_out = self.backbone(x, train=train)
        features, reduce_sim = backbone_out
        logits = self.classifier(features)
        return logits, reduce_sim
|
|
class L2P(nn.Module):
    """Learning to Prompt (L2P) continual-learning classifier.

    Wraps a prompt-capable backbone and a linear head in ``Model``, trains
    only the parameters whose names contain ``prompt`` or ``classifier``, and
    restricts the training logits to the class range of the current task.
    """

    def __init__(self, backbone, device, **kwargs):
        """Build the network, create the prompt pool and freeze the rest.

        Args:
            backbone: Module exposing ``create_prompt(...)``; it is wrapped in
                ``Model`` together with a linear classifier.
            device: Device the network and every incoming batch are moved to.
            **kwargs: Must contain ``init_cls_num``, ``inc_cls_num``,
                ``num_class``, ``task_num``, ``feat_dim``,
                ``pull_constraint_coeff``, ``prompt_length``, ``pool_size``
                and ``top_k``.

        Raises:
            KeyError: If one of the required keys is missing from ``kwargs``.
        """
        super().__init__()

        self.device = device
        self.init_cls_num = kwargs['init_cls_num']
        self.inc_cls_num = kwargs['inc_cls_num']
        self.total_cls_num = kwargs['num_class']
        self.task_num = kwargs['task_num']
        self.embed_dim = kwargs['feat_dim']
        self.pull_constraint_coeff = kwargs['pull_constraint_coeff']
        self.cur_task_id = 0
        # Number of classes covered by the tasks finished so far.
        self._known_classes = 0

        self.network = Model(backbone, self.embed_dim, self.total_cls_num)
        self.network.backbone.create_prompt(
            prompt_flag='l2p',
            length=kwargs['prompt_length'],
            prompt_init=nn.init.uniform_,
            pool_size=kwargs['pool_size'],
            top_k=kwargs['top_k'],
            num_layers=1,
            embed_dim=self.embed_dim,
        )
        self.network.to(self.device)

        # Only prompt and classifier parameters are trained; everything else
        # is frozen. The attribute keeps its original (misspelled) name so
        # that code reading it from outside this class keeps working.
        self.unfrezeed_params = []
        for name, param in self.network.named_parameters():
            trainable = 'prompt' in name or 'classifier' in name
            param.requires_grad_(trainable)
            if trainable:
                self.unfrezeed_params.append(param)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Record ``task_idx`` as the current task; other arguments are unused."""
        self.cur_task_id = task_idx

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Add the classes of the finished task to the known-class counter.

        Task 0 contributes ``init_cls_num`` classes, every later task
        ``inc_cls_num``. The remaining arguments are unused.
        """
        self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num

    def observe(self, data):
        """Run the forward and backward pass for one training batch.

        Logits of every class outside the current task's range are set to
        ``-inf`` before the loss is computed. The loss is the cross-entropy
        minus ``pull_constraint_coeff * reduce_sim``. Gradients are computed
        here and clipped to a total norm of 1.0 over the trainable
        parameters; this method neither zeroes gradients nor steps an
        optimizer.

        Args:
            data: Mapping with an ``'image'`` tensor and a ``'label'`` tensor.

        Returns:
            tuple: ``(pred, acc, loss)`` -- predicted class indices, the
            fraction of correct predictions in the batch as a float, and the
            loss tensor.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)
        logits, reduce_sim = self.network(x, train=True)

        # Class range [start, start + width) owned by the current task.
        if self.cur_task_id == 0:
            start, width = 0, self.init_cls_num
        else:
            start, width = self._known_classes, self.inc_cls_num

        # Build the mask directly on the logits' device instead of going
        # through NumPy and a host-to-device copy on every step.
        class_ids = torch.arange(self.total_cls_num, device=logits.device)
        outside_task = (class_ids < start) | (class_ids >= start + width)
        logits = logits.masked_fill(outside_task, float('-inf'))

        loss = F.cross_entropy(logits, y) - self.pull_constraint_coeff * reduce_sim

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.unfrezeed_params, 1.0)

        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item() / x.size(0)

        return pred, acc, loss

    def inference(self, data):
        """Predict classes for a batch and report the batch accuracy.

        No class masking is applied here: the argmax runs over all logits.

        Args:
            data: Mapping with an ``'image'`` tensor and a ``'label'`` tensor.

        Returns:
            tuple: ``(pred, acc)`` -- predicted class indices and the fraction
            of correct predictions in the batch as a float.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        # Neither return value can carry gradients (argmax / Python float),
        # so skip building an autograd graph for the trainable parameters.
        with torch.no_grad():
            logits, _ = self.network(x, train=False)

        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item() / x.size(0)
        return pred, acc

    def get_parameters(self, config):
        """Return the list of trainable parameters; ``config`` is unused."""
        return self.unfrezeed_params
|
|