|
|
| """
|
| @inproceedings{zhao2020maintaining,
|
| title={Maintaining discrimination and fairness in class incremental learning},
|
| author={Zhao, Bowen and Xiao, Xi and Gan, Guojun and Zhang, Bin and Xia, Shu-Tao},
|
| booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)},
|
| pages={13208--13217},
|
| year={2020}
|
| }
|
| https://arxiv.org/abs/1911.07053
|
|
|
| Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/wa.py, https://github.com/G-U-N/PyCIL/blob/master/utils/inc_net.py.
|
| """
|
|
|
| import torch
|
| from torch import nn
|
| import copy
|
| from torch.nn import functional as F
|
| import numpy as np
|
| from .finetune import Finetune
|
|
|
|
|
def KD_loss(pred, soft, T=2):
    """Return the temperature-scaled knowledge distillation loss.

    Code Reference:
    https://github.com/G-U-N/PyCIL/blob/master/models/wa.py

    Both inputs are divided by ``T`` before normalisation along ``dim=1``.
    The result is the cross-entropy between the softened targets and the
    softened predictions, summed over all entries and divided by the size
    of the first dimension of ``pred``.

    Args:
        pred (torch.Tensor): Logits of the current model.
        soft (torch.Tensor): Logits that are turned into soft targets.
        T (float): Temperature used to soften both distributions. Default is 2.

    Returns:
        torch.Tensor: Scalar knowledge distillation loss.
    """
    student_log_probs = torch.log_softmax(pred / T, dim=1)
    teacher_probs = torch.softmax(soft / T, dim=1)
    batch_size = student_log_probs.shape[0]

    weighted_total = (teacher_probs * student_log_probs).sum()
    return -weighted_total / batch_size
|
|
|
|
|
class IncrementalModel(nn.Module):
    """A backbone followed by a linear classifier that can grow over time.

    Code Reference:
    https://github.com/G-U-N/PyCIL/blob/master/utils/inc_net.py

    The classifier starts out as ``None``; ``update_classifier`` must be
    called before the first forward pass.

    Args:
        backbone (nn.Module): Backbone network. Its output is indexed with
            ``['features']`` to obtain the feature tensor.
        feat_dim (int): Dimension of the extracted features.
        num_class (int): Number of classes in the dataset (only stored).
    """

    def __init__(self, backbone, feat_dim, num_class):
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = None

    def forward(self, x):
        """Return the classifier logits for ``x`` (see ``get_logits``)."""
        return self.get_logits(x)

    def get_logits(self, x):
        """Compute logits for the input data.

        Args:
            x (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Logits of the input data.
        """
        return self.classifier(self.extract_vector(x))

    def extract_vector(self, x):
        """Extract features from the backbone network.

        Args:
            x (torch.Tensor): Input data.

        Returns:
            torch.Tensor: The ``'features'`` entry of the backbone output.
        """
        return self.backbone(x)["features"]

    def update_classifier(self, number_classes):
        """Replace the classifier with a larger one, keeping learned rows.

        A fresh ``nn.Linear(feat_dim, number_classes)`` is created. If a
        classifier already exists, its weights and biases are copied into
        the leading rows of the new layer; the remaining rows keep their
        default initialisation.

        Args:
            number_classes (int): Number of classes after the update.
        """
        classifier = nn.Linear(self.feat_dim, number_classes)
        previous = self.classifier
        if previous is not None:
            kept = previous.out_features
            # copy_ duplicates the values, so no separate deepcopy is needed.
            with torch.no_grad():
                classifier.weight[:kept].copy_(previous.weight)
                classifier.bias[:kept].copy_(previous.bias)

        self.classifier = classifier

    def classifier_weight_align(self, incremental_number):
        """Rescale the newest classifier rows to match the older rows' norm.

        The last ``incremental_number`` rows are multiplied by
        ``mean(||old rows||) / mean(||new rows||)``, in place.

        Args:
            incremental_number (int): Number of classes added in the
                current task.
        """
        weights = self.classifier.weight.data
        # With no new rows or no old rows one of the two groups is empty,
        # its mean norm is NaN, and scaling would poison every weight.
        if incremental_number <= 0 or incremental_number >= weights.shape[0]:
            return

        new_norm = torch.norm(weights[-incremental_number:], p=2, dim=1)
        old_norm = torch.norm(weights[:-incremental_number], p=2, dim=1)
        gamma = old_norm.mean() / new_norm.mean()
        # `weights` shares storage with the parameter, so this updates it.
        weights[-incremental_number:] *= gamma

    def freeze(self):
        """Disable gradients for all parameters and switch to eval mode.

        Returns:
            IncrementalModel: ``self``, to allow call chaining.
        """
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

        return self
|
|
|
|
|
class WA(Finetune):
    """Weight Aligning learner built on top of ``Finetune``.

    Training combines cross-entropy with a distillation term against a
    frozen copy of the previous network. After every task except the first,
    the classifier rows of the newly added classes are rescaled to the mean
    norm of the older rows.

    Args:
        backbone (nn.Module): Backbone network, forwarded to ``Finetune``.
        feat_dim (int): Dimension of the extracted features.
        num_class (int): Number of classes in the dataset.
        **kwargs: Must contain ``'init_cls_num'`` and ``'device'``.
            ``'inc_cls_num'`` is used for later tasks when present.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__(backbone, feat_dim, num_class, **kwargs)
        # NOTE(review): self.backbone is presumably set by Finetune.__init__.
        self.network = IncrementalModel(self.backbone, feat_dim, kwargs['init_cls_num'])
        self.device = kwargs['device']
        self.old_network = None  # frozen snapshot of the previous task's network
        self.known_classes = 0   # classes seen before the current task
        self.total_classes = 0   # classes seen including the current task
        self.task_idx = 0        # internal counter, advanced in after_task

        self.total_classes_indexes = 0

    def observe(self, data):
        """Run one training step on a batch.

        Args:
            data (dict): Dictionary with ``'image'`` and ``'label'`` tensors.

        Returns:
            tuple: Predictions, batch accuracy, and the loss tensor.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        self.network.to(self.device)
        if self.old_network is not None:
            self.old_network.to(self.device)

        logits = self.network(x)
        loss = F.cross_entropy(logits, y)

        if self.task_idx > 0:
            # Weight the distillation term by the share of old classes.
            kd_lambda = self.known_classes / self.total_classes
            # The old network was frozen in after_task; no graph is needed.
            with torch.no_grad():
                old_logits = self.old_network(x)
            loss_kd = KD_loss(logits[:, : self.known_classes], old_logits)
            loss = (1 - kd_lambda) * loss + kd_lambda * loss_kd

        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()

        return pred, acc / x.size(0), loss

    def inference(self, data):
        """Perform inference on a batch.

        Args:
            data (dict): Dictionary with ``'image'`` and ``'label'`` tensors.

        Returns:
            tuple: Predictions and batch accuracy.
        """
        x, y = data['image'].to(self.device), data['label'].to(self.device)

        # Only predictions and a count leave this method, so skip autograd.
        with torch.no_grad():
            logits = self.network(x)
        pred = torch.argmax(logits, dim=1)
        acc = torch.sum(pred == y).item()
        return pred, acc / x.size(0)

    def forward(self, x):
        """Return the logits of the current network for ``x``."""
        return self.network(x)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Grow the classifier for the upcoming task.

        Args:
            task_idx (int): Index of the current task.
            buffer (Buffer): Buffer object.
            train_loader (DataLoader): DataLoader for training data.
            test_loaders (list): List of DataLoaders for test data.
        """
        # NOTE(review): self.kwargs is presumably the dict stored by
        # Finetune.__init__ - confirm.
        init_cls_num = self.kwargs['init_cls_num']
        if self.task_idx == 0:
            added_classes = init_cls_num
        else:
            # Later tasks may add a different number of classes than the
            # first; fall back to init_cls_num when no increment is given.
            added_classes = self.kwargs.get('inc_cls_num', init_cls_num)

        self.total_classes += added_classes
        self.network.update_classifier(self.total_classes)

        self.total_classes_indexes = np.arange(self.known_classes, self.total_classes)

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Align weights, snapshot the network, and refresh the buffer.

        Args:
            task_idx (int): Index of the current task.
            buffer (Buffer): Buffer object.
            train_loader (DataLoader): DataLoader for training data.
            test_loaders (list): List of DataLoaders for test data.
        """
        if self.task_idx > 0:
            # Only meaningful once there are old classes to align against.
            self.network.classifier_weight_align(self.total_classes - self.known_classes)
        self.old_network = copy.deepcopy(self.network).freeze()
        self.known_classes = self.total_classes

        buffer.reduce_old_data(self.task_idx, self.total_classes)
        val_transform = test_loaders[0].dataset.trfms
        buffer.update(self.network, train_loader, val_transform,
                      self.task_idx, self.total_classes, self.total_classes_indexes,
                      self.device)

        self.task_idx += 1
|
|
|