# Source path: LibContinual/core/model/dualprompt.py
# Repository page header captured with this file: "boringKey's picture",
# commit message "Upload 236 files", commit 5fee096 (verified).
# -*- coding: utf-8 -*-
"""
@inproceedings{DBLP:conf/eccv/0002ZESZLRSPDP22,
author = {Zifeng Wang and
Zizhao Zhang and
Sayna Ebrahimi and
Ruoxi Sun and
Han Zhang and
Chen{-}Yu Lee and
Xiaoqi Ren and
Guolong Su and
Vincent Perot and
Jennifer G. Dy and
Tomas Pfister},
editor = {Shai Avidan and
Gabriel J. Brostow and
Moustapha Ciss{\'{e}} and
Giovanni Maria Farinella and
Tal Hassner},
title = {DualPrompt: Complementary Prompting for Rehearsal-Free Continual Learning},
booktitle = {Computer Vision - {ECCV} 2022 - 17th European Conference, Tel Aviv,
Israel, October 23-27, 2022, Proceedings, Part {XXVI}},
volume = {13686},
pages = {631--648},
publisher = {Springer},
year = {2022}
}
https://arxiv.org/pdf/2204.04799
Adapted from https://github.com/GT-RIPL/CODA-Prompt
"""
import math
import copy
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from .finetune import Finetune
from core.model.backbone.resnet import *
import numpy as np
from torch.utils.data import DataLoader
class Model(nn.Module):
    """Pair a feature-extracting backbone with a linear classification head.

    Args:
        backbone: module invoked as ``backbone(x, train=...)``. With
            ``train=True`` it must return a ``(features, loss)`` pair; with
            ``train=False`` it must return the features alone.
        feat_dim: input width of the linear classifier.
        num_class: number of logits produced by the linear classifier.
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        self.feat_dim = feat_dim
        self.num_class = num_class
        # Submodules are registered backbone-first, classifier-second.
        self.backbone = backbone
        self.classifier = nn.Linear(feat_dim, num_class)

    def forward(self, x, train=True):
        """Classify ``x``.

        Args:
            x: input batch, forwarded to the backbone unchanged.
            train: selects which backbone calling convention is used.

        Returns:
            ``(logits, loss)`` when ``train`` is truthy, where ``loss`` is the
            second value returned by the backbone; otherwise ``logits`` only.
        """
        if not train:
            return self.classifier(self.backbone(x, train=False))

        features, backbone_loss = self.backbone(x, train=True)
        logits = self.classifier(features)
        return logits, backbone_loss
class DualPrompt(Finetune):
    """DualPrompt learner: a prompted backbone with a linear head that grows per task.

    The constructor reads ``init_cls_num``, ``task_num``, ``e_prompt_length``
    and ``g_prompt_length`` from ``kwargs``; ``before_task`` additionally reads
    ``inc_cls_num``. ``self.backbone`` and ``self.device`` are not set here —
    presumably the ``Finetune`` base class provides them (confirm there).
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        """Build the network and ask the backbone to create its 'dual' prompts.

        Args:
            backbone: forwarded to ``Finetune.__init__``; the network is then
                built around ``self.backbone``.
            feat_dim: feature width fed to the linear classifier.
            num_class: forwarded to ``Finetune.__init__``; the classifier itself
                starts with ``kwargs['init_cls_num']`` outputs.
            **kwargs: run configuration (see the class docstring for the keys
                read by this class).
        """
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        self.kwargs = kwargs
        self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num'])
        # NOTE(review): the leading 10 is presumably the E-prompt pool size
        # expected by the backbone's create_prompt — confirm against the backbone.
        self.network.backbone.create_prompt(
            'dual',
            n_tasks=kwargs['task_num'],
            prompt_param=[10, kwargs['e_prompt_length'], kwargs['g_prompt_length']],
        )
        self.task_idx = 0
        # Classifier width at the end of the previous task (0 before any task).
        self.last_out_dim = 0

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare for task ``task_idx``: widen the classifier and reset the loss state.

        The new head has ``init_cls_num + task_idx * inc_cls_num`` outputs; the
        rows of the current head are copied into its first rows.

        Args:
            task_idx: index of the task about to start.
            buffer, train_loader, test_loaders: unused here.
        """
        self.task_idx = task_idx
        self.network.backbone.task_id = task_idx

        old_fc = self.network.classifier
        out_features = old_fc.out_features
        new_out_features = self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num']
        new_fc = nn.Linear(old_fc.in_features, new_out_features)
        # Carry the already-learned rows over without recording the copy in autograd.
        with torch.no_grad():
            new_fc.weight[:out_features].copy_(old_fc.weight)
            new_fc.bias[:out_features].copy_(old_fc.bias)
        self.network.classifier = new_fc
        self.network.to(self.device)

        # Per-sample losses so they can be weighted before averaging.
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')
        self.out_dim = new_out_features
        # All-ones weight vector with one extra trailing slot; observe() reads
        # only that last entry. Allocated directly on the target device.
        self.dw_k = torch.ones(self.out_dim + 1, dtype=torch.float32, device=self.device)

    def observe(self, data):
        """Run one training forward pass on a batch.

        Args:
            data: mapping with keys ``'image'`` and ``'label'``.

        Returns:
            ``(pred, acc, loss)``: arg-max predictions, the fraction of correct
            predictions in the batch, and the backbone loss plus the weighted
            mean cross-entropy.
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        logit, loss = self.network(x, train=True)
        # Columns below the previous task's head width are excluded from the
        # softmax and from the prediction.
        logit[:, :self.last_out_dim] = -float('inf')

        # Every sample uses the same weight (the last entry of dw_k), so a
        # scalar broadcast replaces building an index tensor each step.
        dw_cls = self.dw_k[-1]
        # Out-of-place add: do not mutate the tensor returned by the backbone.
        loss = loss + (self.loss_fn(logit, y) * dw_cls).mean()

        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0), loss

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Record the current head width so the next task can mask those logits.

        All arguments are unused here.
        """
        self.last_out_dim = self.out_dim

    def inference(self, data):
        """Predict labels for a batch; no logits are masked here.

        Args:
            data: mapping with keys ``'image'`` and ``'label'``.

        Returns:
            ``(pred, acc)``: arg-max predictions and the fraction of correct
            predictions in the batch.
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        logit = self.network(x, train=False)
        pred = torch.argmax(logit, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def get_parameters(self, config):
        """Return the backbone's prompt parameters followed by the classifier's.

        Args:
            config: unused here.
        """
        return list(self.network.backbone.prompt.parameters()) + list(self.network.classifier.parameters())