| """ |
| @article{He_2025_CVPR, |
| author = {He, Jiangpeng and Duan, Zhihao and Zhu, Fengqing}, |
| title = {CL-LoRA: Continual Low-Rank Adaptation for Rehearsal-Free Class-Incremental Learning}, |
| journal = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)}, |
| month = {June}, |
| year = {2025}, |
| pages = {30534-30544} |
| } |
| |
| Adapted from https://github.com/JiangpengHe/CL-LoRA |
| """ |
|
|
| import math |
| import torch |
|
|
| import numpy as np |
| import torch.nn as nn |
|
|
| from tqdm import tqdm |
| from torch import optim |
| from copy import deepcopy |
| from torch.nn import functional as F |
|
|
| from .backbone.transformer import MultiHeadAttention_CL_LoRA |
|
|
| def _KD_loss(pred, soft, T): |
| pred = torch.log_softmax(pred / T, dim=1) |
| soft = torch.softmax(soft / T, dim=1) |
| return -1 * torch.mul(soft, pred).sum() / pred.shape[0] |
|
|
def compute_orthogonality_loss(previous_weights_list, current_weights, epsilon=1e-8):
    """Mean absolute cosine similarity between current and previous weights.

    Each tensor is flattened and L2-normalized (with ``epsilon`` guarding
    against division by zero); the penalty is the average of the absolute dot
    products against every previous weight tensor, so minimizing it pushes the
    current weights toward orthogonality with all earlier tasks' weights.

    Args:
        previous_weights_list: iterable of weight tensors from earlier tasks.
        current_weights: weight tensor of the current task.
        epsilon: small constant added to norms for numerical stability.

    Returns:
        Scalar tensor with the averaged penalty, or 0.0 when the list is empty.
    """
    cur = current_weights.flatten()
    cur = cur / (cur.norm() + epsilon)

    if not previous_weights_list:
        return 0.0

    penalty = 0.0
    for prev_weights in previous_weights_list:
        prev = prev_weights.flatten()
        prev = prev / (prev.norm() + epsilon)
        penalty = penalty + torch.abs(torch.dot(prev, cur))

    return penalty / len(previous_weights_list)
|
|
class CosineLinearFeature(nn.Module):
    """Linear head that scores inputs by cosine similarity with its weight rows.

    Logits are dot products of L2-normalized inputs with L2-normalized weight
    rows, optionally scaled by a learnable scalar ``sigma``.  With
    ``nb_proxy > 1`` each class owns several proxy rows that can be folded
    back to per-class scores when ``to_reduce`` is set.
    """

    def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True):
        super(CosineLinearFeature, self).__init__()
        self.in_features = in_features
        self.out_features = out_features * nb_proxy
        self.nb_proxy = nb_proxy
        self.to_reduce = to_reduce
        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
        if sigma:
            # Learnable temperature applied to the cosine logits.
            self.sigma = nn.Parameter(torch.Tensor(1))
        else:
            self.register_parameter('sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)]; sigma starts at 1."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def reset_parameters_to_zero(self):
        """Zero the weight matrix (rows are later overwritten with prototypes)."""
        self.weight.data.fill_(0)

    def forward(self, input):
        """Return {'logits': cosine scores} for a (batch, in_features) input."""
        out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
        if self.to_reduce:
            # reduce_proxies is defined elsewhere in this package — only needed
            # when nb_proxy > 1.
            out = reduce_proxies(out, self.nb_proxy)
        if self.sigma is not None:
            out = self.sigma * out
        return {'logits': out}

    def forward_diagonal(self, input, cur_task, alpha=0., beta=0.0, init_cls=10, inc=10, out_dim=768, use_init_ptm=False):
        """Block-diagonal scoring: each task's feature slice vs. its own classes.

        The input concatenates one ``out_dim``-wide feature slice per task;
        slice ``i`` is matched only against the classifier rows of the classes
        introduced by task ``i`` and the per-task logits are concatenated.
        ``alpha``, ``beta`` and ``use_init_ptm`` are accepted for interface
        compatibility but unused.
        """
        outs = []
        for i in range(cur_task + 1):
            # Class range owned by task i: [0, init_cls) for the first task,
            # then consecutive chunks of size `inc`.
            if i == 0:
                start_cls, end_cls = 0, init_cls
            else:
                start_cls = init_cls + (i - 1) * inc
                end_cls = start_cls + inc
            feat = F.normalize(input[:, i * out_dim:(i + 1) * out_dim], p=2, dim=1)
            rows = F.normalize(self.weight[start_cls:end_cls, i * out_dim:(i + 1) * out_dim], p=2, dim=1)
            outs.append(F.linear(feat, rows))
        # Single concatenation replaces the original per-iteration torch.cat,
        # which also carried a dead `if i != 0 else out` ternary.
        out_all = torch.cat(outs, dim=1)

        if self.to_reduce:
            out_all = reduce_proxies(out_all, self.nb_proxy)
        if self.sigma is not None:
            out_all = self.sigma * out_all
        return {'logits': out_all}
|
|
class Model(nn.Module):
    """Incremental network: a shared backbone plus growing cosine heads.

    ``proxy_fc`` scores only the current task's classes during training;
    ``fc`` spans one ``out_dim``-wide feature slice per seen task and scores
    all seen classes block-diagonally at test time (see
    ``CosineLinearFeature.forward_diagonal``).
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()
        self.backbone = backbone
        self.inc = kwargs["inc_cls_num"]        # classes added per incremental task
        self.init_cls = kwargs["init_cls_num"]  # classes in the first task
        self._cur_task = -1                     # incremented by update_fc()
        self.out_dim = 768                      # backbone feature dimension
        self.fc = None                          # aggregated head over all seen tasks
        self.alpha = 0.
        self.beta = 0
        self.fc_list = nn.ModuleList()          # frozen per-task proxy heads
        self.fc_list_task = nn.ModuleList()
        self.adapter_list = nn.ModuleList()
        self.init_proto = None                  # prototype head over all seen classes
        self._device = device

    def freeze(self):
        """Disable gradients for every parameter in the model."""
        for param in self.parameters():
            param.requires_grad = False

    @property
    def feature_dim(self):
        # One out_dim-wide feature slice per task seen so far.
        return self.out_dim * (self._cur_task + 1)

    def update_fc(self, nb_classes):
        """Advance to the next task and grow the heads to ``nb_classes``.

        Creates a fresh ``proxy_fc`` sized for the new task's classes, and
        rebuilds ``init_proto`` and ``fc`` while copying the rows already
        learned for old classes.  The new ``fc`` starts zeroed and frozen —
        its rows are later filled with class prototypes (see replace_fc in
        the trainer).
        """
        self._cur_task += 1

        if self._cur_task == 0:
            self.proxy_fc = self.generate_fc(self.out_dim, self.init_cls).to(self._device)
        else:
            self.proxy_fc = self.generate_fc(self.out_dim, self.inc).to(self._device)

        init_proto = self.generate_fc(self.out_dim, nb_classes).to(self._device)
        if self.init_proto is not None:
            # Carry over the rows learned for the old classes.
            old_nb_classes = self.init_proto.out_features
            init_proto.weight.data[:old_nb_classes, :] = self.init_proto.weight.data.clone()
            del self.init_proto
        self.init_proto = init_proto

        fc = self.generate_fc(self.feature_dim, nb_classes).to(self._device)
        fc.reset_parameters_to_zero()
        if self.fc is not None:
            old_nb_classes = self.fc.out_features
            fc.sigma.data = self.fc.sigma.data
            # Old rows only cover the previous tasks' feature slices; the new
            # trailing out_dim columns stay zero until prototypes are written.
            fc.weight.data[:old_nb_classes, : -self.out_dim] = self.fc.weight.data.clone()
            del self.fc
        self.fc = fc
        self.fc.requires_grad_(False)

    def add_fc(self):
        """Archive the current task's proxy head (frozen) into ``fc_list``."""
        self.fc_list.append(self.proxy_fc.requires_grad_(False))
        del self.proxy_fc

    def generate_fc(self, in_dim, out_dim):
        """Build a cosine-similarity head of the given shape."""
        return CosineLinearFeature(in_dim, out_dim)

    def forward_kd(self, x, t_idx):
        """Return (student, teacher) proxy-head outputs for distillation."""
        x_new, x_teacher = self.backbone.forward_general_cls(x, t_idx)
        return self.proxy_fc(x_new), self.proxy_fc(x_teacher)

    def forward(self, x, test=False):
        """Training path scores current-task classes via ``proxy_fc``;
        test path scores all seen classes block-diagonally via ``fc``."""
        if not test:
            feats = self.backbone.forward(x, test=False)
            out = self.proxy_fc(feats)
            out.update({"features": feats})
            return out

        feats = self.backbone.forward(x, test=True)
        out = self.fc.forward_diagonal(feats, cur_task=self._cur_task, alpha=0., init_cls=self.init_cls, inc=self.inc, use_init_ptm=False, beta=0)
        out.update({"features": feats})
        return out
|
|
class CL_LoRA(nn.Module):
    """Trainer for CL-LoRA rehearsal-free class-incremental learning.

    Each task trains task-specific LoRA adapters plus a proxy classifier.
    Task-shared adapter components are regularized by knowledge distillation
    (with per-row gradient reweighting), and task-specific block weights by
    an orthogonality penalty against previous tasks.
    """

    def __init__(self, backbone, device, **kwargs):
        super().__init__()

        self.device = device
        self.init_cls_num = kwargs["init_cls_num"]
        self.inc_cls_num = kwargs["inc_cls_num"]
        self.task_num = kwargs["task_num"]
        self._known_classes = 0
        self._total_classes = 0
        self._cur_task = 0

        self._network = Model(backbone, device, **kwargs)
        self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_CL_LoRA)]

        self.lora_modules = [[] for _ in range(self.task_num)]
        self.lora_scale_weights = [[] for _ in range(self.task_num)]
        self.optim = None  # injected later via set_optim()

    def observe(self, data):
        """One training step on a batch.

        Computes cross-entropy on the current task's classes (labels shifted
        into the task-local range).  For incremental tasks it additionally
        (1) optimizes a KD loss between the shared-adapter student and the
        previous-task teacher immediately — with lora_A gradients rescaled by
        the old adapters' per-row weight norms — and (2) adds an orthogonality
        penalty on the task-specific block weights.

        Returns:
            (predictions, batch accuracy, loss) — the caller is expected to
            backpropagate the returned loss.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        # Re-index labels into the current task's local class range.
        aux_targets = y - self._known_classes

        logits = self._network(x, test=False)['logits']
        loss = F.cross_entropy(logits, aux_targets)

        if self._cur_task > 0:
            kd_ratio = 5.
            Temperature = 2

            out_new, out_teacher = self._network.forward_kd(x, self._cur_task)
            loss_kd = kd_ratio * _KD_loss(out_new["logits"], out_teacher["logits"], T=Temperature)

            # The KD loss is stepped immediately (separately from the returned
            # loss), after rescaling the shared adapters' lora_A gradients.
            self.optim.zero_grad()
            loss_kd.backward()

            feat = self._network.backbone.feat
            for j in range(len(feat.general_pos)):
                pos = feat.adapt_pos.index(feat.general_pos[j])
                for jj in range(len(feat.msa)):
                    if feat.msa[jj] == 1:
                        # Per-row importance from the previous task's lora_A norms,
                        # normalized to mean 1 across rows.
                        temp_weights = 1. * torch.norm(feat.old_adapter_list[self._cur_task - 1][pos][jj].lora_A.weight, dim=1)
                        temp_weights = 1. * len(temp_weights) * temp_weights / torch.sum(temp_weights)
                        feat.cur_adapter[pos][jj].lora_A.weight.grad = temp_weights.unsqueeze(1) * feat.cur_adapter[pos][jj].lora_A.weight.grad

            self.optim.step()

            # Push new task-specific block weights orthogonal to old ones.
            orth_loss_specific = compute_orthogonality_loss(feat.block_weight_list, feat.block_weight)
            loss += 0.0001 * orth_loss_specific

        preds = logits.max(1)[1]
        correct_count = preds.eq(aux_targets).sum().item()
        acc = correct_count / y.size(0)

        return preds, acc, loss

    def inference(self, data):
        """Predict over all seen classes; return (predictions, batch accuracy)."""
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        logits = self._network(x, True)["logits"]
        preds = logits.max(1)[1]

        correct_count = preds.eq(y).sum().item()
        acc = correct_count / y.size(0)

        return preds, acc

    @torch.no_grad()
    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Prepare for task ``task_idx``: freeze old state, grow heads, mark trainables.

        Trainable for the new task: the current adapters' LoRA weights (except
        lora_B of the first six blocks, which stay shared/frozen), the new
        proxy head, ``init_proto``, and the current block weights.
        """
        if task_idx > 0:
            self._known_classes = self._total_classes
            self._network.freeze()
            self._network.backbone.add_adapter_to_list()

        self._cur_task = task_idx
        self._total_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num
        self._network.update_fc(self._total_classes)

        for name, param in self._network.named_parameters():
            # Default everything to frozen, then selectively re-enable.
            # (The original code had an earlier if/else here that was dead —
            # it was unconditionally overwritten by this line.)
            param.requires_grad_(False)

            if 'lora' in name and 'cur_adapter' in name:
                # lora_B of blocks 0-5 is task-shared and stays frozen.
                # (The original condition ended in the bare literal
                # 'cur_adapter', which is always truthy; the outer `if`
                # already guarantees 'cur_adapter' in name.)
                shared_lora_b = 'lora_B' in name and any(f'er.{i}.' in name for i in range(6))
                if not shared_lora_b:
                    param.requires_grad_(True)
            elif 'proxy_fc' in name:
                param.requires_grad_(True)
            elif 'init_proto' in name:
                param.requires_grad_(True)
            elif 'block_weight' in name and 'old' not in name:
                param.requires_grad_(True)

        self._network = self._network.to(self.device)

    @torch.no_grad()
    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Finalize a task: archive the proxy head and write class prototypes.

        The train loader's transforms are swapped to the test transforms so
        prototype extraction sees deterministic (non-augmented) inputs.
        """
        self._network.add_fc()
        train_loader.dataset.trfms = test_loaders[0].dataset.trfms
        self.replace_fc(train_loader)

        self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num

    def replace_fc(self, train_loader):
        """Fill the aggregated head ``fc`` with class-mean prototype rows.

        For every adapter index seen so far, embeddings of the current task's
        data are extracted with that adapter, and each class's row of ``fc``
        (within that adapter's feature slice) is set to the class mean.
        """
        model = self._network.eval()

        with torch.no_grad():
            for index in range(self._cur_task + 1):
                embedding_list, label_list = [], []
                for batch in train_loader:
                    data = batch['image'].to(self.device)
                    label = batch['label'].to(self.device)
                    embedding = model.backbone.forward_proto(data, adapt_index=index)
                    embedding_list.append(embedding.cpu())
                    label_list.append(label.cpu())

                embedding_list = torch.cat(embedding_list, dim=0)
                label_list = torch.cat(label_list, dim=0)

                class_list = np.unique(train_loader.dataset.labels)
                for class_index in class_list:
                    data_index = (label_list == class_index).nonzero().squeeze(-1)
                    proto = embedding_list[data_index].mean(0)
                    model.fc.weight.data[class_index, index * self._network.out_dim:(index + 1) * self._network.out_dim] = proto

    def get_parameters(self, config):
        """Return all network parameters for the external optimizer."""
        return self._network.parameters()

    def set_optim(self, optim):
        """Inject the optimizer used by the in-step KD update in observe()."""
        self.optim = optim