| """ |
| @inproceedings{10.24963/ijcai.2024/456, |
| author = {Hong, Chenxing and Jin, Yan and Kang, Zhiqi and Chen, Yizhou and Li, Mengke and Lu, Yang and Wang, Hanzi}, |
| title = {Dynamically anchored prompting for task-imbalanced continual learning}, |
| booktitle = {Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence}, |
year = {2024},
| } |
| https://dl.acm.org/doi/10.24963/ijcai.2024/456 |
| Adapted from https://github.com/chenxing6666/dap |
| """ |
|
|
| import math |
| import copy |
| import torch |
| import torch.nn.functional as F |
| from .finetune import Finetune |
| import numpy as np |
| from torch.utils.data import DataLoader |
|
|
# Module-level trackers carried over from the reference implementation.
# NOTE(review): these are declared `global` inside DAP's static methods below
# but are never read or reassigned anywhere in this file — they look vestigial.
# TODO: confirm nothing outside this module touches them before removing.
global_max_dist = torch.tensor(0)
global_max_dist2 = torch.tensor(0)
# NOTE(review): unused in this file; presumably a leftover hyper-parameter.
global_lam = 0.25
|
|
|
|
| class DAP(Finetune): |
| def __init__(self, backbone, feat_dim, num_class, **kwargs): |
| super().__init__(backbone, feat_dim, num_class, **kwargs) |
| self.kwargs = kwargs |
| self.network = backbone |
| self.train_mask = kwargs['train_mask'] |
| self.task_inc = kwargs['task_inc'] |
| self.pull_constraint = kwargs['pull_constraint'] |
| self.pull_constraint_coeff = kwargs['pull_constraint_coeff'] |
|
|
| self.task_idx = 0 |
| self.task_data_count = [] |
| self.prompt_center = None |
|
|
| |
| if self.num_class % kwargs['task_num'] != 0: |
| raise ValueError('Number of classes must be divisible by number of tasks') |
| classes_per_task = self.num_class // kwargs['task_num'] |
| self.class_mask = [list(range(i * classes_per_task, (i + 1) * classes_per_task)) for i in range(kwargs['task_num'])] |
|
|
| self.original_model = copy.deepcopy(self.backbone) |
| self.original_model.to(self.device) |
| self.original_model.eval() |
|
|
| if kwargs['freeze']: |
| |
| for p in self.original_model.parameters(): |
| p.requires_grad = False |
|
|
| |
| for n, p in self.network.named_parameters(): |
| if n.startswith(tuple(kwargs['freeze'])): |
| p.requires_grad = False |
|
|
| self.loss_fn.to(self.device) |
|
|
    def observe(self, data, train_gprompt=False, gen=False):
        """Run one training step on a batch.

        Args:
            data: dict with at least 'image' and 'label' tensors.
            train_gprompt: when True, add the DAP anchor losses (stability +
                plasticity) instead of the pull constraint.
            gen: forwarded to the network to toggle its generation path
                (semantics defined by the network — TODO confirm).

        Returns:
            (pred, accuracy, loss) for the batch.

        Raises:
            RuntimeError: if the loss becomes non-finite (NaN/inf).
        """
        x, y = data['image'], data['label']
        x = x.to(self.device)
        y = y.to(self.device)

        # Query features come from the frozen pretrained copy; no gradient
        # is needed for that forward pass.
        with torch.no_grad():
            if self.original_model is not None:
                output = self.original_model(x)
                cls_features = output['pre_logits']
            else:
                cls_features = None
        # Trainable forward (outside no_grad so prompts receive gradients).
        if gen:
            output = self.network(x, task_id=self.task_idx, cls_features=cls_features, train=True, gen=gen)
        else:
            output = self.network(x, task_id=self.task_idx, cls_features=cls_features, train=True)
        logits = output['logits']

        # During training, mask out logits of classes that do not belong to
        # the current task so the loss only sees the live label range.
        if self.train_mask and self.class_mask is not None:
            mask = self.class_mask[self.task_idx]
            not_mask = np.setdiff1d(np.arange(self.num_class), mask)
            not_mask = torch.tensor(not_mask, dtype=torch.int64).to(self.device)
            logits = logits.index_fill(
                dim=1, index=not_mask, value=float('-inf'))

        if (train_gprompt):
            # Plasticity term: pull the general prompt toward the current
            # task prompt.
            pla_similarity_loss_res = self.cal_latestsimilarity_loss(
                model=self.network, task_id=self.task_idx)
            # Stability term: keep the general prompt near the center of
            # past task prompts. NOTE(review): assumes self.prompt_center has
            # been populated (see cal_center) whenever task_idx > 0 — confirm
            # the training loop guarantees this.
            sta_similarity_loss_res = self.cal_similarity_loss(model=self.network, task_id=self.task_idx, prompt_center=self.prompt_center)

            pla_similarity_loss = pla_similarity_loss_res['similarity']
            sta_similarity_loss = sta_similarity_loss_res['avg_similarity']

            # alpha in [0, 1]: how large the current task is relative to the
            # smallest/largest tasks seen so far. Big current task -> lean on
            # stability; small current task -> lean on plasticity.
            min_data_count = min(self.task_data_count)
            max_data_count = max(self.task_data_count)
            last_data_count = self.task_data_count[-1]
            epsilon = 1e-10  # avoids division by zero when all tasks are equal-sized
            alpha = (last_data_count - min_data_count) / (max_data_count - min_data_count + epsilon)

            loss2 = alpha*sta_similarity_loss
            loss3 = (1-alpha)*pla_similarity_loss

            loss = self.loss_fn(logits, y) + loss2 + loss3
        else:
            loss = self.loss_fn(logits, y)
            # Optional prompt-pull regularizer (as in L2P/DualPrompt).
            if self.pull_constraint and 'reduce_sim' in output:
                loss = loss - self.pull_constraint_coeff * output['reduce_sim']

        # Fail fast on divergence rather than training on NaNs.
        if not math.isfinite(loss.item()):
            raise RuntimeError(f'Loss is {loss.item()}, stopping training')

        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()

        return pred, acc / x.size(0), loss
|
|
| def inference(self, data): |
| x, y = data['image'], data['label'] |
| x = x.to(self.device) |
| y = y.to(self.device) |
|
|
| with torch.no_grad(): |
| if self.original_model is not None: |
| output = self.original_model(x) |
| cls_features = output['pre_logits'] |
| else: |
| cls_features = None |
| output = self.network(x, task_id=self.task_idx, cls_features=cls_features, gen=True) |
| logits = output['logits'] |
|
|
| |
| if self.task_inc and self.class_mask is not None: |
| mask = self.class_mask[self.task_idx] |
| mask = torch.tensor(mask, dtype=torch.int64).to(self.device) |
| logits_mask = torch.ones_like(logits, device=self.device) * float('-inf') |
| logits_mask = logits_mask.index_fill(1, mask, 0.0) |
| logits = logits + logits_mask |
|
|
| pred = torch.argmax(logits, dim=1) |
| acc = torch.sum(pred == y).item() |
|
|
| return pred, acc / x.size(0) |
|
|
| def before_task(self, task_idx, buffer, train_loader, test_loaders): |
| self.task_idx = task_idx |
| self.network.task_id = task_idx |
| self.task_data_count.append(len(train_loader.dataset)) |
|
|
| @staticmethod |
| def cal_latestsimilarity_loss(model: torch.nn.Module, task_id=-1): |
| res = dict() |
| global global_max_dist2 |
|
|
| gprompt = model.prompt.generalprompt |
| tprompt = model.prompt.taskprompt[task_id].detach() |
|
|
| gprompt_flat = gprompt.view(-1) |
| tprompt_tensors = tprompt.view(-1) |
| similarity = 1-F.cosine_similarity(gprompt_flat, tprompt_tensors, dim=0) |
| res['similarity'] = similarity |
| return res |
|
|
| @staticmethod |
| def cal_center(model: torch.nn.Module, task_id=-1, task_data_count=None, prompt_center=None): |
| tprompt = model.prompt.taskprompt |
| if task_id > 0: |
| if prompt_center is None: |
| prompt_center = tprompt[0].detach().view(-1) |
| current_tprompt = tprompt[task_id - 1].detach().view(-1) |
| if task_data_count: |
| weights = [1 / count for count in task_data_count[:task_id]] |
| normalized_weight = weights[-1] / sum(weights) |
| weights2 = sum(weights[:-1]) / sum(weights) |
| else: |
| normalized_weight = 1.0 / task_id |
| prompt_center = (prompt_center * weights2) + \ |
| (current_tprompt * normalized_weight) |
| else: |
| prompt_center = torch.zeros_like(tprompt[0].detach().view(-1)) |
| return prompt_center |
|
|
| @staticmethod |
| def cal_similarity_loss(model: torch.nn.Module, task_id=-1, prompt_center=None): |
| res = dict() |
| global global_max_dist |
|
|
| gprompt = model.prompt.generalprompt |
|
|
| if task_id > 0: |
| gprompt_flat = gprompt.view(-1) |
| similarity = 1-F.cosine_similarity(gprompt_flat, prompt_center, dim=0) |
| res['similarity'] = similarity |
| res['avg_similarity'] = similarity |
| else: |
| res['similarity'] = torch.tensor(0) |
| res['avg_similarity'] = 0 |
| return res |