diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..2caacc638f4eb06c6234dcf6ab5afc672589c6b9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +resources/imgs/cl_lora.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/codaprompt.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/dap.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/flowchart.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/InfLoRA.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/LUCIR.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/moe_adapter4cl.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/praka.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/rapf.png filter=lfs diff=lfs merge=lfs -text +resources/imgs/wa.png filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..02aed090961efc8676bf6347d26e2558d2fbdd81 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 VIG@R&L + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 7be5fc7f47d5db027d120b8024982df93db95b74..4672be93b60794052f35728af814bee248f308a4 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,178 @@ ---- -license: mit ---- + +
+

LibContinual: Make Continual Learning Easy

+ + +
+ +## Introduction +
+LibContinual is an open-source continual learning toolbox based on PyTorch. The framework currently supports PyTorch 1.13+ (compatibility with earlier versions not fully guaranteed) and provides comprehensive implementations of state-of-the-art continual learning algorithms. +
+ + + + +## Supported Methods +### Conventional methods ++ [LwF (ECCV 2016)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/lwf/README.md): Learning without Forgetting. ++ [EWC (PNAS 2017)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/ewc/README.md): Overcoming catastrophic forgetting in neural networks. ++ [iCaRL (CVPR 2017)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/icarl/README.md): Incremental Classifier and Representation Learning. ++ [BiC (CVPR 2019)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/bic/README.md): Large Scale Incremental Learning. ++ [LUCIR (CVPR 2019)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/lucir/README.md): Learning a Unified Classifier Incrementally via Rebalancing. ++ [WA (CVPR 2020)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/wa/README.md): Maintaining Discrimination and Fairness in Class Incremental Learning. ++ [OCM (ICML 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/ocm/README.md): Online Continual Learning through Mutual Information Maximization. ++ [ERACE, ERAML (ICLR 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/erace,eraml/README.md): New Insights on reducing abrupt representation change in online continual learning. ++ [GPM (ICLR 2021)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/gpm/README.md): Gradient Projection Memory for Continual Learning. ++ [TRGP (ICLR 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/trgp/README.md): Trust Region Gradient Projection for Continual Learning. ++ [API (CVPR 2023)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/api/README.md): Adaptive Plasticity Improvement for Continual Learning. ++ [RanPAC (NeurIPS 2023)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/ranpac/README.md): Random Projections and Pre-trained Models for Continual Learning. + + + +### Foundation model based methods ++ [L2P (CVPR 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/l2p/README.md): Learning to Prompt for Continual Learning. ++ [DualPrompt (ECCV 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/dualprompt/README.md): Complementary Prompting for Rehearsal-free Continual Learning. ++ [CodaPrompt (CVPR 2023)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/codaprompt/README.md): COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning. ++ [InfLoRA (CVPR 2024)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/inflora/README.md): Interference-Free Low-Rank Adaptation for Continual Learning. ++ [MoE_Adapter4CL (CVPR 2024)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/moe_adapter4cl/README.md): Boosting Continual Learning of Vision-Language Models via Mixture-of-Experts Adapters. ++ [RAPF (ECCV 2024)](https://github.com/RL-VIG/LibContinual/tree/master/reproduce/rapf): Class-Incremental Learning with CLIP: Adaptive Representation Adjustment and Parameter Fusion. ++ [SD_LoRA (ICLR 2025)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/sd_lora/README.md): Scalable Decoupled Low-Rank Adaptation for Class Incremental Learning ++ [LoRA_Sub_DRS (CVPR 2025)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/lora_sub_drs/README.md): LoRA Subtraction for Drift-Resistant Space in Exemplar-Free Continual Learning ++ [CL-LoRA (CVPR 2025)](https://github.com/RL-VIG/LibContinual/tree/master/reproduce/cl_lora/README.md): Continual Low-Rank Adaptation for Rehearsal-Free Class-Incremental Learning + + + +## Installation +Please refer to [`install.md`](https://libcontinual.readthedocs.io/en/latest/docs/install.html)
+Complete tutorials can be found at [`./docs`](https://libcontinual.readthedocs.io/en/latest/) + + +## Datasets +- CIFAR-10 is avaliable at [Google Drive](https://drive.google.com/drive/folders/1sl2aW1sRpEfQJuJZwajXO2QhR06gQYZx?usp=drive_link)
+- CIFAR-100 is available at [Google Drive](https://drive.google.com/drive/folders/1EL46LQ3ww-F1NVTwFDPIg-nO198cUqWm?usp=sharing)
+- CUB200, ImageNet-R, Tiny-Imagenet, 5-Dataset are available at [Google Drive](https://drive.google.com/drive/folders/16afRW2952coWJSbiH7cZT1b8pRibA8nH?usp=sharing)
+ +After the dataset is downloaded, please extract the compressed file to the specified path. +``` +unzip cifar100.zip -d /path/to/your/dataset +``` +Set the `data_root` in `.yaml`: +``` +data_root: /path/to/your/dataset +``` +To add a custom dataset, please refer to [`dataset.md`](https://libcontinual.readthedocs.io/en/latest/docs/data_module_en.html). + +## Getting Started + +Once you have completed the "Installation" and "Datasets" sections, you can now proceed to demonstrate how to use the "LibContinual" framework with the [`LUCIR`](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/lucir/README.md) method. + +- **Step1:** Configure the parameters in the `./config/lucir.yaml` file. Please refer to [`config.md`](https://libcontinual.readthedocs.io/en/latest/docs/config_file_en.html) for the meanings of each parameter. +- **Step2:** Run code `python run_trainer.py --config lucir.yaml` +- **Step3:** After the training is completed, the log files will be saved in the path specified by the `save_path` parameter. + +## Benchmarks + +We adopt standardized evaluation metrics from continual learning literature. Given T tasks where $R_{t,i}$ represents the accuracy of model after training on task $t$ when tested on task $i$: + + + +### Evaluation Metrics + +#### **1.** Last Average Accuracy +$$ + Acc_T=R_{T, {0\sim T}} \quad (1) +$$ + + +#### **2.** Backward Transfer (BWT) +$$ + BWT_T = \frac{\sum_{i=3}^T\sum_{j=1}^{i-2}R_{i,j}-R{j,j}}{T(T-1)/2} \quad (2) +$$ + +#### **3.** Forgetting +$$ + Frgt_T = \frac{\sum_{j=1}^{T-2}R_{T-1,j}-R_{j,j}}{T-1} \quad (3) +$$ +> Equivalent to Positive BwT in ["new metrics for Continual Learning"](https://arxiv.org/pdf/1810.13166) + + +#### **4.** Overall Average Accuracy +$$ + \overline{Acc_T}=\frac{1}{T}\sum_{t=1}^T(\frac{1}{t}\sum_{i=1}^t R_{t,i}) \quad (4) +$$ + + + + +## Acknowledgement +LibContinual is an open source project designed to help continual learning researchers quickly understand the classic methods and code structures. We welcome other contributors to use this framework to implement their own or other impressive methods and add them to LibContinual. This library can only be used for academic research. We welcome any feedback during using LibContinual and will try our best to continually improve the library. +Special thanks to the authors of [FACIL](https://github.com/mmasana/FACIL) and [PyCIL](https://github.com/G-U-N/PyCIL) for their inspiration on framework design. + + + + +We have referenced useful modules from these repositories in our work. We deeply appreciate the authors of these repositories. + +## License +This project is licensed under the MIT License. See LICENSE for more details. diff --git a/README_ZH.md b/README_ZH.md new file mode 100644 index 0000000000000000000000000000000000000000..d78176fffe124c96f3d3769ebbaa3f0b13a1b325 --- /dev/null +++ b/README_ZH.md @@ -0,0 +1,62 @@ +# LibContinual +Make continual learning easy. + +## Introduction +LibContinual is an open source continual learning toolbox based on PyTorch. The master branch works with PyTorch 1.13. The compatibility to earlier versions of PyTorch is not fully tested. + +![flowchart](./resources/imgs/flowchart.png) + +## Supported Methods ++ [BiC (CVPR 2019)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/bic/README.md) ++ [EWC (PNAS 2017)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/ewc/README.md) ++ [iCaRL (CVPR2017)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/icarl/README.md) ++ [LUCIR (CVPR 2019)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/lucir/README.md) ++ [LwF (ECCV 2016)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/lwf/README.md) ++ [WA (CVPR 2020)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/wa/README.md) ++ [OCM (PMLR 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/ocm/README.md) ++ [DER (CVPR 2021)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/der/README.md) ++ [ERACE,ERAML (ICLR 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/erace,eraml/README.md) ++ [L2P (CVPR 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/l2p/README.md) ++ [DualPrompt (ECCV 2022)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/dualprompt/README.md) ++ [CodaPrompt (CVPR 2023)](https://github.com/RL-VIG/LibContinual/blob/master/reproduce/codaprompt/README.md) + +## Quick Installation +(待文档部分完成)
+请参考文档中[`安装`](https://github.com/RL-VIG/LibContinual/blob/master/docs/tutorials/install.md)部分。
+完整文档:[`./docs`](https://github.com/RL-VIG/LibContinual/tree/master/docs) + +## Datasets +[`CIFAR-100`](https://drive.google.com/drive/folders/1EL46LQ3ww-F1NVTwFDPIg-nO198cUqWm?usp=sharing), `miniImageNet(todo)`
+ +将对应数据集的压缩包解压至指定路径: +``` +unzip cifar100.zip -d /path/to/your/dataset +``` +修改.yaml文件的data_root参数: +``` +data_root: /path/to/your/dataset +``` +如何添加自定义数据集请参考文档:[`添加自定义数据集`](https://github.com/RL-VIG/LibContinual/blob/master/docs/tutorials/zh/data_module_zh.md) + +## Get Start + +当您已经完成`Quick Installation`和`Datasets`后,我们以`LUCIR`方法为例展示如何使用`LibContinual`。 +- **Step1**: 修改`run_trainer.py`中`Config`参数为`./config/lucir.yaml` +- **Step2**:配置`./config/lucir.yaml`文件中的参数,各参数含义请参考[配置文件](https://github.com/RL-VIG/LibContinual/blob/master/docs/tutorials/config_file.md) +- **Step3**: 运行代码`python run_trainer.py` +- **Step4**:日志保存在配置文件中`save_path`路径下 + + +## Acknowledgement +LibContinual is an open source project designed to help continual learning researchers quickly understand the classic methods and code structures. We welcome other contributors to use this framework to implement their own or other impressive methods and add them to LibContinual. This library can only be used for academic research. We welcome any feedback during using LibContinual and will try our best to continually improve the library. + + +在本项目开发过程中参考了下列仓库: + +- [FACIL](https://github.com/mmasana/FACIL) +- [PyCIL](https://github.com/G-U-N/PyCIL) + +在我们的工作中参考了这些仓库中有用的模块。我们深深感谢这些仓库的作者们。 + +## License +This project is licensed under the MIT License. See LICENSE for more details. \ No newline at end of file diff --git a/config/InfLoRA.yaml b/config/InfLoRA.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c279a8c45ad28315cdf60e46dc49def5a9ca879 --- /dev/null +++ b/config/InfLoRA.yaml @@ -0,0 +1,64 @@ +save_path : "" + +image_size: 224 +save_path: ./ +init_cls_num: 10 +inc_cls_num: 10 +task_num: 10 +val_per_epoch: 20 + + +epoch: 20 +n_gpu: 1 +seed: 2 + +shuffle: false + +batch_size: 128 + + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0 + betas: [0.9, 0.999] + + + + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + +backbone: + name: SiNet_vit + kwargs: + total_sessions: 10 + rank: 10 + init_cls: 10 + embd_dim: 768 + + +buffer: + name: LinearBuffer + kwargs: + buffer_size: 0 + batch_size: 128 + strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: InfLoRA + kwargs: + feat_dim: 64 + num_class: 100 + inc_cls_num: 10 + # device: 0 + lame: 1.0 + lamb: 0.95 + total_sessions: 10 + + + + diff --git a/config/InfLoRA_opt-vit-cifar100-b10-10-10.yaml b/config/InfLoRA_opt-vit-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ccf016ac77121d622c2363140b5f38a269f8c51 --- /dev/null +++ b/config/InfLoRA_opt-vit-cifar100-b10-10-10.yaml @@ -0,0 +1,72 @@ +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: 20 +val_per_epoch: 20 + +batch_size: 128 + +setting: task-agnostic +seed: 1993 + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_LoRA + lora_rank: 10 + +classifier: + name: InfLoRA_OPT + kwargs: + use_ca: False + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + \ No newline at end of file diff --git a/config/InfLoRA_opt-vit-imagenetr-b20-20-10.yaml b/config/InfLoRA_opt-vit-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..153b240c41b6cea55179b89ba8c18c30b8dd819b --- /dev/null +++ b/config/InfLoRA_opt-vit-imagenetr-b20-20-10.yaml @@ -0,0 +1,73 @@ + +dataset: &dataset "imagenet-r" +data_root: "/home/lvqiexuan/temp_data/imagenet-r/" + +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +image_size: &image_size 224 +epoch: &epoch 20 + +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch + +workers: 24 + +seed: 42 + +batch_size: 128 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.333] + - RandomHorizontalFlip: + p: 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: 256 + interpolation: BICUBIC + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + lr: 8e-3 + momentum: 0.9 + +lr_scheduler: + name: Constant + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k # vit_base_patch16_224 + attn_layer: MultiHeadAttention_LoRA + lora_rank: 10 + +classifier: + name: InfLoRA_OPT + kwargs: + use_ca: False + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + \ No newline at end of file diff --git a/config/InfLoRA_opt.yaml b/config/InfLoRA_opt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ac7fa8ff01d6ac53c2925578965b2b23c97e779 --- /dev/null +++ b/config/InfLoRA_opt.yaml @@ -0,0 +1,71 @@ +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: 20 +val_per_epoch: 20 + +batch_size: 128 # 128 + +setting: task-agnostic + +testing_times: 5 + +train_trfms: + - RandomResizedCrop: + size: *image_size + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_LoRA + lora_rank: 10 + +classifier: + name: InfLoRA_OPT + kwargs: + use_ca: False + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + \ No newline at end of file diff --git a/config/InfLoRA_opt_clip.yaml b/config/InfLoRA_opt_clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e61e5aa180849dfcef8272da240c6c2bb253a7d --- /dev/null +++ b/config/InfLoRA_opt_clip.yaml @@ -0,0 +1,75 @@ +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 +image_size: &image_size 224 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: 20 # 20 +val_per_epoch: 20 + +batch_size: 128 # 128 + +setting: task-agnostic + +testing_times: 5 + +train_trfms: + - RandomResizedCrop: + size: *image_size + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + + +backbone: + name: clip + kwargs: + pretrained : True + model_name : ViT-B/16 + experts_num: 0 + act_layer: QuickGELU + norm_layer: LayerNorm + attn_layer: MultiHeadAttention_LoRA + +classifier: + name: InfLoRA_OPT + kwargs: + use_ca: False + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + prompt_template : "a bad photo of a {}." # For CLIP + visual_only: True # For CLIP, apply lora to only visual encoder or visual and text encoder \ No newline at end of file diff --git a/config/PRAKA.yaml b/config/PRAKA.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c80dcc32267b8a3e7f2ca0d85c67353607f0a397 --- /dev/null +++ b/config/PRAKA.yaml @@ -0,0 +1,78 @@ + +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 50 # 50 50 +inc_cls_num: &inc_cls_num 10 # 5 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 6 # 11 6 +image_size: &image_size 32 + +image_size: *image_size +# data +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num +batch_size: 128 # 128 +epoch: 100 # 100 + +val_per_epoch: 100 +seed: 2 + +testing_times: 10 # 10 + +train_trfms: + - RandomCrop: + size: [*image_size, *image_size] + padding: 4 + - RandomHorizontalFlip: + p: 0.5 + - ColorJitter: + brightness: 0.24705882352941178 + - ToTensor: {} + - Normalize: + mean: [0.5071, 0.4866, 0.4409] # don't change + std: [0.2675, 0.2565, 0.2761] # don't change + #mean: [0.5071, 0.4867, 0.4408] + #std: [0.2675, 0.2565, 0.2761] + + +test_trfms: + - ToTensor: {} + - Normalize: + mean: [0.5071, 0.4866, 0.4409] # don't change + std: [0.2675, 0.2565, 0.2761] # don't change + #mean: [0.5071, 0.4867, 0.4408] + #std: [0.2675, 0.2565, 0.2761] + +optimizer: + name: Adam + kwargs: + lr: 0.001 + #betas: [0.9, 0.999] + weight_decay: 2e-4 + #eps: 1e-8 + +lr_scheduler: + name: CosineAnnealingLR + kwargs: + T_max: 32 + +backbone: + name: resnet18_cbam + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +classifier: + name: PRAKA + kwargs: + num_class: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + feat_dim: 512 + + log_root: log + total_nc: *total_cls_num + protoAug_weight: 15.0 + kd_weight: 15.0 + temp: 0.1 \ No newline at end of file diff --git a/config/api_til-alexnet-cifar100-b5-5-20.yaml b/config/api_til-alexnet-cifar100-b5-5-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e967e0642dcbe11d6a93e47724294eb181b201e6 --- /dev/null +++ b/config/api_til-alexnet-cifar100-b5-5-20.yaml @@ -0,0 +1,39 @@ +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 100 +task_num: &task_num 20 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-aware # task-aware +seed: 2 +testing_times: 1 # Don't set too high, it will take eternity + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_API + kwargs: + +classifier: + name: API + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/backbones/CifarResnet.yaml b/config/backbones/CifarResnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c882f42a38ad493c40e53877783fb65bac05ce0f --- /dev/null +++ b/config/backbones/CifarResnet.yaml @@ -0,0 +1,7 @@ +backbone: + name: CifarResnet + kwargs: + keep_prob: 0.0 + avg_pool: True + is_flatten: True + maxpool_last2: True \ No newline at end of file diff --git a/config/backbones/resnet12.yaml b/config/backbones/resnet12.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ddc4ef8a6180ac10e1154b2152857a8ce2ce69bb --- /dev/null +++ b/config/backbones/resnet12.yaml @@ -0,0 +1,7 @@ +backbone: + name: resnet12 + kwargs: + keep_prob: 0.0 + avg_pool: True + is_flatten: True + maxpool_last2: True \ No newline at end of file diff --git a/config/codaprompt.yaml b/config/codaprompt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ed19d1d18083ddb967919900b81dffcf243190b0 --- /dev/null +++ b/config/codaprompt.yaml @@ -0,0 +1,72 @@ + +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +image_size: *image_size +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num + +epoch: 20 +val_per_epoch: 20 + +batch_size: 128 + +train_trfms: + - RandomResizedCrop: + size: *image_size + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.001 + betas: [0.9, 0.999] + weight_decay: 0 + +#lr_scheduler: +# name: MultiStepLR +# kwargs: +# gamma: 0.1 +# milestones: [80, 120] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + +backbone: + name: vit_pt_imnet + kwargs: + num_classes: *total_cls_num + pretrained: true + model_name : vit_base_patch16_224 + +classifier: + name: CodaPrompt + kwargs: + num_class: *total_cls_num + task_num: *task_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + feat_dim: 768 + prompt_length: 8 + pool_size: 100 + mu: 0.0 + diff --git a/config/dap.yaml b/config/dap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6a3ee6f3baec71fcf287a51509c1dd5eabe0b55 --- /dev/null +++ b/config/dap.yaml @@ -0,0 +1,74 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + +data_root: data/cifar100/ +image_size: 224 +save_path: ./ +init_cls_num: 10 +inc_cls_num: 10 +task_num: 10 + +epoch: 5 +n_gpu: 1 +seed: 42 +val_per_epoch: 5 + +imb_type: exp_re +imb_factor: 0.002 +shuffle: false + +batch_size: 64 + +optimizer: + name: Adam + kwargs: + lr: 0.01 + eps: 1e-8 + weight_decay: 0.0 + betas: [0.9, 0.9] + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [80, 120] + +backbone: + name: vit_pt_imnet_dap + kwargs: + pretrained: true + model_name: vit_base_patch16_224_dap + num_classes: 100 + drop: 0.0 + drop_path: 0.0 + length: 5 + embedding_key: cls + prompt_key_init: uniform + prompt_pool: true + prompt_key: true + size: 10 + top_k: 5 + batchwise_prompt: true + head_type: prompt + use_prompt_mask: false + +classifier: + name: DAP + kwargs: + num_class: 100 + feat_dim: 768 + task_num: 10 + init_cls_num: 10 + inc_cls_num: 10 + train_mask: true + task_inc: false + freeze: + - blocks + - patch_embed + - cls_token + - norm + - pos_embed + pull_constraint: true + pull_constraint_coeff: 0.1 \ No newline at end of file diff --git a/config/der.yaml b/config/der.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4632bd0a03eec7dd5aa638682deb3b2701540d8b --- /dev/null +++ b/config/der.yaml @@ -0,0 +1,59 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + - headers/optimizer.yaml + - backbones/resnet12.yaml + +image_size: 32 + +save_path: ./ +# data +init_cls_num: 10 +inc_cls_num: 10 +task_num: 10 + +# init_epoch can be none +init_epoch: 170 #200 +epoch: 170 #170 +device_ids: 2 +n_gpu: 1 +val_per_epoch: 170 + +batch_size: 128 + +optimizer: + name: SGD + kwargs: + momentum: 0.9 + lr: 0.1 + weight_decay: 2e-4 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [70, 100, 150] + +backbone: + name: resnet18 + kwargs: + num_classes: 10 + args: + dataset: cifar100 + + +buffer: + name: LinearBuffer + kwargs: + buffer_size: 0 # 2000 + batch_size: 32 + strategy: random # random, equal_random, reservoir, herding + +classifier: + name: DER + kwargs: + num_class: 100 + feat_dim: 512 + init_cls_num: 10 + inc_cls_num: 10 diff --git a/config/dmnsp.yaml b/config/dmnsp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ba3b8d43a656d8b5169d5d57131fc052abbc791 --- /dev/null +++ b/config/dmnsp.yaml @@ -0,0 +1,71 @@ + +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 100 +task_num: &task_num 20 + +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 +val_per_epoch: 4 + +train_batch_size: 128 +test_batch_size: 64 + +testing_times: 10 + +# setting: task-agnostic # class-incremental settings +setting: task-aware # task-incremental settings + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: DMNSP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + prompt_template : "a bad photo of a {}." + label_smoothing: 0. + lamda_scale: 30 diff --git a/config/dmnsp_imgnr.yaml b/config/dmnsp_imgnr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b095ed5620024283302095d13d26558b83814b5 --- /dev/null +++ b/config/dmnsp_imgnr.yaml @@ -0,0 +1,55 @@ + +dataset: &dataset "imagenet-r" +data_root: "/home/lvqiexuan/temp_data/imagenet-r/" + +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 200 +task_num: &task_num 20 +image_size: &image_size 224 + +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 +val_per_epoch: 4 + +train_batch_size: 128 +test_batch_size: 64 + +testing_times: 10 + +#setting: task-agnostic +setting: task-aware + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: DMNSP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + prompt_template : "a bad photo of a {}." + label_smoothing: 0. + lamda_scale: 30 diff --git a/config/dmnsp_vit.yaml b/config/dmnsp_vit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d03afd7e73dd5d4c79fef26a66fc75e424d370d --- /dev/null +++ b/config/dmnsp_vit.yaml @@ -0,0 +1,69 @@ +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 # 4 +val_per_epoch: 4 # 4 + +train_batch_size: 128 +test_batch_size: 64 + +testing_times: 10 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + experts_num: 1 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: DMNSP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/dualprompt.yaml b/config/dualprompt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c25d3ea3119a255df8cdf8b1bdc07361a987321 --- /dev/null +++ b/config/dualprompt.yaml @@ -0,0 +1,54 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + # - headers/optimizer.yaml + - backbones/resnet12.yaml + +image_size: 32 + +# data +init_cls_num: 10 +inc_cls_num: 10 +task_num: 10 + +epoch: 10 #160 +device_ids: 0 +n_gpu: 1 +val_per_epoch: 5 + + +batch_size: 128 + + +optimizer: + name: Adam + kwargs: + lr: 0.001 + betas: [0.9, 0.999] + weight_decay: 0 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [80, 120] + +backbone: + name: vit_pt_imnet + kwargs: + num_classes: 100 + pretrained: true + model_name : vit_base_patch16_224 + +classifier: + name: DualPrompt + kwargs: + num_class: 100 + feat_dim: 768 + task_num: 10 + init_cls_num: 10 + inc_cls_num: 10 + g_prompt_length: 6 + e_prompt_length: 20 + diff --git a/config/ewc.yaml b/config/ewc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fdea5892f877572d04a332b3165e4a7adedf309e --- /dev/null +++ b/config/ewc.yaml @@ -0,0 +1,41 @@ +image_size: 32 + +# data +init_cls_num: 10 # 50 +inc_cls_num: 10 # 25 +task_num: 10 # 3 + +epoch: 100 # 100 +n_gpu: 1 +val_per_epoch: 50 + +batch_size: 128 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0005 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [60, 120, 170] + +backbone: + name: resnet34 # cifar_resnet32 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + +classifier: + name: EWC + kwargs: + num_class: 100 + feat_dim: 512 # 64 for backbone cifar_resnet32 + init_cls_num: 10 + inc_cls_num: 10 + lamda: 1000 diff --git a/config/finetune.yaml b/config/finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dbcd3cb98a26c746708663629629922c72603dad --- /dev/null +++ b/config/finetune.yaml @@ -0,0 +1,53 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + - headers/optimizer.yaml + - backbones/resnet12.yaml + +data_root: /home/xiongyakun/cifar10 +image_size: 32 + +save_path: ./ +# data +init_cls_num: 2 +inc_cls_num: 2 +task_num: 5 + + +epoch: 3 +device_ids: 5 +n_gpu: 1 +val_per_epoch: 1 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + +lr_scheduler: + name: StepLR + kwargs: + gamma: 0.5 + step_size: 10 + +backbone: + name: resnet18 + kwargs: + num_classes: 10 + args: + dataset: cifar10 + + +buffer: + name: LinearBuffer + kwargs: + buffer_size: 500 + batch_size: 32 + strategy: random # random, equal_random, reservoir, herding + +classifier: + name: Finetune + kwargs: + num_class: 10 + feat_dim: 512 \ No newline at end of file diff --git a/config/headers/data.yaml b/config/headers/data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa2de4842af7ac9cb440001b0e0f127aeca2f3d3 --- /dev/null +++ b/config/headers/data.yaml @@ -0,0 +1,7 @@ +save_path: '' +dataset: cifar100 # [cifar100, binary_cifar100, ...] +data_root: /data/lqx/cifar100 +image_size: 32 +pin_memory: False +augment: True +num_workers: 24 \ No newline at end of file diff --git a/config/headers/device.yaml b/config/headers/device.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0282af01ff262c7539fa3ab10b0c978e9a4f2d47 --- /dev/null +++ b/config/headers/device.yaml @@ -0,0 +1,4 @@ +device_ids: 'auto' # determine GPU ids in bus order, auto for automatic select gpu, or number to specify device +n_gpu: 1 # select the number of gpus to use +seed: 1993 # random seed for numpy, torch and cuda +deterministic: True # option for torch.backends.cudnn.benchmark and torch.backends.cudnn.deterministic \ No newline at end of file diff --git a/config/headers/model.yaml b/config/headers/model.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3463f3e36358603252572bee5a355f25334f6633 --- /dev/null +++ b/config/headers/model.yaml @@ -0,0 +1,11 @@ +epoch: 50 + +batch_size: 64 +val_per_epoch: 1 + +buffer: # By default Buffer is not used, set buffer_size to 0 + name: LinearBuffer + kwargs: + buffer_size: 0 + batch_size: 128 + strategy: herding # random, equal_random, reservoir, herding \ No newline at end of file diff --git a/config/headers/optimizer.yaml b/config/headers/optimizer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..01720d65eb071eb607f9bf9eb22902686b8e2161 --- /dev/null +++ b/config/headers/optimizer.yaml @@ -0,0 +1,11 @@ +# optimizer info +optimizer: + name: SGD + kwargs: + lr: 0.1 + +# lr_scheduler info +lr_scheduler: # By Default, No LR Scheduler is used + name: Constant + +warmup: 3 \ No newline at end of file diff --git a/config/headers/test.yaml b/config/headers/test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..11f79461dabd04d04456b5b8a6e74e75924317b2 --- /dev/null +++ b/config/headers/test.yaml @@ -0,0 +1,11 @@ +# Testing Settings +testing_times: 10 # Take average of 10 testings +setting: task-agnostic # Task ID is not provided during inference +# or task-aware, provide Task ID during inference + +testing_per_task: True # Test data comes in per task (Each batch of test data will belongs to same task) +# False, each batch of test data is not assure to be in same task +# Not yet Implemented + +eval_with_test: True # Use Test data for in-epoch validation, If False +# Another eval dateset will seperated from train dataset for in-epoch validation \ No newline at end of file diff --git a/config/icarl.yaml b/config/icarl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3750fa807af6738bae4838d34d1a5eb54278cadb --- /dev/null +++ b/config/icarl.yaml @@ -0,0 +1,68 @@ +includes: + - headers/data.yaml + - headers/device.yaml + # - headers/model.yaml + # - headers/optimizer.yaml + # - backbones/resnet12.yaml + +# data_root: /data/fanzhichen/continual/cifar100 +data_root: ~/datasets/cifar100 +image_size: 32 + + +warmup: 3 + + +save_path: ./ +# data +init_cls_num: 10 +inc_cls_num: 10 +task_num: 10 + + +batch_size: 128 + +init_epoch: 200 #100 +epoch: 170 #100 + +device_ids: 2 +n_gpu: 1 +val_per_epoch: 1 + + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0005 + +lr_scheduler: + name: CosineAnnealingLR + kwargs: + T_max: 100 + + +backbone: + name: cifar_resnet32 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + + +buffer: + name: LinearHerdingBuffer + kwargs: + buffer_size: 2000 + batch_size: 64 + # strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: ICarl + kwargs: + num_class: 100 + feat_dim: 64 + init_cls_num: 10 + inc_cls_num: 10 + task_num: 10 \ No newline at end of file diff --git a/config/icarl_5dataset.yaml b/config/icarl_5dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7f9fa048020739735fd238c41f1c5d68c56f34b --- /dev/null +++ b/config/icarl_5dataset.yaml @@ -0,0 +1,72 @@ +includes: + - headers/data.yaml + - headers/device.yaml + # - headers/model.yaml + # - headers/optimizer.yaml + # - backbones/resnet12.yaml + +warmup: 0 + + +dataset: &dataset 5-datasets +data_root: /data/Dataset/5-dataset +class_order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] +testing_times: 1 + +save_path: ./ +# data +init_cls_num: 10 +inc_cls_num: 10 +task_num: 5 + + +batch_size: 128 + +init_epoch: 1 #100 +epoch: 1 #100 + +device_ids: 2 +n_gpu: 1 +val_per_epoch: 1 + + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0005 + +lr_scheduler: + name: CosineAnnealingLR + kwargs: + T_max: 100 + + +backbone: + name: cifar_resnet32 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + + +buffer: + name: LinearHerdingBuffer + kwargs: + buffer_size: 2000 + batch_size: 64 + # strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: ICarl + kwargs: + num_class: 50 + feat_dim: 64 + init_cls_num: 10 + inc_cls_num: 10 + task_num: 5 \ No newline at end of file diff --git a/config/l2p-vit-cifar100-b10-10-10.yaml b/config/l2p-vit-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac7c9292442e34c500bf0f382ddd330b4ac0aad1 --- /dev/null +++ b/config/l2p-vit-cifar100-b10-10-10.yaml @@ -0,0 +1,67 @@ + +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num +epoch: 1 # 5 +val_per_epoch: 5 + +batch_size: 16 # Source code is 16 per device * 8 devices, since we don't use distribution device, set batch_size to 16 +testing_times: 1 + +seed: 2 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.3333] # [0.75, 1.3333333333] + interpolation: BILINEAR + - RandomHorizontalFlip: + p: 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: 256 # Stated in source code of L2P + interpolation: BICUBIC # 3 # LANCZOS + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: Adam + kwargs: + lr: 0.001875 # 0.03 + betas: [0.9, 0.999] + weight_decay: 0 + +lr_scheduler: + name: Constant + +backbone: + name: vit_pt_imnet + kwargs: + num_classes: 100 + pretrained: true + model_name : vit_base_patch16_224 + +classifier: + name: L2P + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + num_class: *total_cls_num + task_num: *task_num + feat_dim: 768 + prompt_length: 5 # L_p in paper + pool_size: 10 # M in paper + top_k: 5 # N in paper + pull_constraint_coeff: 1.0 # -0.5 in paper, 1.0 in source code + diff --git a/config/lwf.yaml b/config/lwf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..59a8d8834ac3de918829f3308f53c0da5b0ee989 --- /dev/null +++ b/config/lwf.yaml @@ -0,0 +1,42 @@ +image_size: 32 + +# data +init_cls_num: 20 # 10 +inc_cls_num: 20 # 10 +task_num: 5 # 10 + +epoch: 100 # 100 +n_gpu: 1 +val_per_epoch: 10 + +batch_size: 128 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + +lr_scheduler: + name: StepLR + kwargs: + gamma: 0.5 + step_size: 30 + +backbone: + name: resnet34 # resnet18 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + +classifier: + name: LWF + kwargs: + num_class: 100 + feat_dim: 512 + init_cls_num: 20 #10 + inc_cls_num: 20 #10 + dist: 0.5 + lamda: 10 + K: 2 + lw_mr: 1 diff --git a/config/ranpac.yaml b/config/ranpac.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30427a88a389f51e04c165798a9f898115a4c119 --- /dev/null +++ b/config/ranpac.yaml @@ -0,0 +1,63 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +init_epoch: 20 # 20 +epoch: 1 # 1 +batch_size: 48 # 128 +val_per_epoch: 20 + +seed: 2 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.33333333] # [3./4., 4./3.] + - RandomHorizontalFlip: + p: 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: *image_size + interpolation: BICUBIC + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + momentum: 0.9 + lr: 0.01 + weight_decay: 0.0005 + +lr_scheduler: + name: CosineAnnealingLR + kwargs: + T_max: 20 + eta_min: 0.0 + +backbone: + name: vit_pt_imnet_in21k_adapter + kwargs: + pretrained: true + model_name : vit_base_patch16_224_in21k + +classifier: + name: RanPAC + kwargs: + use_RP: True + M: 10000 + first_session_training: True + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + total_cls_num: *total_cls_num diff --git a/config/ranpac_clip.yaml b/config/ranpac_clip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8dac3849b960485fd472c67b3581842d3c963179 --- /dev/null +++ b/config/ranpac_clip.yaml @@ -0,0 +1,66 @@ +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +init_epoch: 20 # 20 +epoch: 1 # 1 +batch_size: 48 # 128 +val_per_epoch: 20 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.33333333] # [3./4., 4./3.] + - RandomHorizontalFlip: + p: 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: *image_size + interpolation: BICUBIC + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 1 + step: -1 # think again + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: RanPAC + kwargs: + use_RP: True + M: 10000 + first_session_training: True + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + total_cls_num: *total_cls_num + prompt_template : "a bad photo of a {}." diff --git a/config/rapf10-10.yaml b/config/rapf10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae5ba3c38f9601912aa8a932b50f66f141fbe3b4 --- /dev/null +++ b/config/rapf10-10.yaml @@ -0,0 +1,74 @@ +includesL: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + +data_root: /home/xtx/datasets/cifar100 +image_size: &image_size 224 +num_workers: &num_workers 16 + +save_path: ./ + +seed: &seed 1993 + +is_rapf: True + + +# Control B and +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 10 + +epoch: &epoch 15 +batch_size: &batch_size 128 +train_batch_size: &train_batch_size 100 +n_gpu: 1 +beta: &beta 2 +shrinkage: &shrinkage False +threshold: &threshold 0.55 +val_per_epoch: &val_per_epoch 10 + +optimizer: + name: Adam + kwargs: + lr: 0.001 + weight_decay: 0.0000 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [4, 10] + last_epoch: -1 + +backbone: + name: clip + kwargs: + model_name: ViT-B/16 + device: cuda + experts_num: 1 + block_layer: ResidualAttentionBlock_MoE_MLP + top_k : 1 + step: 1 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: RAPF + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + threshold: *threshold + beta: *beta + shrinkage: *shrinkage + train_batch_size: *train_batch_size + batch_size: *batch_size + num_workers: *num_workers + prompt_template: "a good photo of a {}" + seed: *seed + fp16: False + # class_order: *class_order + mix_bias: 0.6 + feat_dim: 64 + num_class: 100 + diff --git a/config/rapf50-10.yaml b/config/rapf50-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f7c3605a017edefc96ff4291dabd7f04e6f100e9 --- /dev/null +++ b/config/rapf50-10.yaml @@ -0,0 +1,72 @@ +includesL: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + +data_root: /home/xtx/datasets/cifar100 +image_size: &image_size 224 +num_workers: &num_workers 16 + +save_path: ./ + +class_order: &class_order [87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, 34, 55, 54, 26, 35, 39] +seed: &seed 1919810 + +is_rapf: True + + +# Control B and +init_cls_num: &init_cls_num 50 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 6 + +epoch: &epoch 15 +batch_size: &batch_size 128 +train_batch_size: &train_batch_size 100 +n_gpu: 1 +beta: &beta 2 +shrinkage: &shrinkage False +threshold: &threshold 0.55 +val_per_epoch: &val_per_epoch 10 + +optimizer: + name: Adam + kwargs: + lr: 0.001 + weight_decay: 0.0000 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [4, 10] + last_epoch: -1 + +backbone: + name: clip + kwargs: + model_name: ViT-B/16 + device: cuda + experts_num: 1 + block_layer: ResidualAttentionBlock_MoE_MLP + top_k : 1 + step: 1 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: RAPF + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + threshold: *threshold + beta: *beta + shrinkage: *shrinkage + train_batch_size: *train_batch_size + batch_size: *batch_size + num_workers: *num_workers + prompt_template: "a good photo of a {}" + seed: *seed + fp16: False + class_order: *class_order + mix_bias: 0.6 diff --git a/config/rapf50-5.yaml b/config/rapf50-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3953f1a4afc54ef64675df2511c868b21467617f --- /dev/null +++ b/config/rapf50-5.yaml @@ -0,0 +1,71 @@ +includesL: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + +data_root: /home/xtx/datasets/cifar100 +image_size: &image_size 224 +num_workers: &num_workers 16 + +save_path: ./ + +seed: &seed 1919810 + +is_rapf: True + + +# Control B and +init_cls_num: &init_cls_num 50 +inc_cls_num: &inc_cls_num 5 +task_num: &task_num 11 + +epoch: &epoch 15 +batch_size: &batch_size 128 +train_batch_size: &train_batch_size 100 +n_gpu: 1 +beta: &beta 2 +shrinkage: &shrinkage False +threshold: &threshold 0.65 +val_per_epoch: &val_per_epoch 10 + +optimizer: + name: Adam + kwargs: + lr: 0.001 + weight_decay: 0.0000 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [4, 10] + last_epoch: -1 + +backbone: + name: clip + kwargs: + model_name: ViT-B/16 + device: cuda + experts_num: 1 + block_layer: ResidualAttentionBlock_MoE_MLP + top_k : 1 + step: 1 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: RAPF + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + threshold: *threshold + beta: *beta + shrinkage: *shrinkage + train_batch_size: *train_batch_size + batch_size: *batch_size + num_workers: *num_workers + prompt_template: "a good photo of a {}" + seed: *seed + fp16: False + class_order: *class_order + mix_bias: 0.6 \ No newline at end of file diff --git a/config/tam.yaml b/config/tam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1e8fa7fd543102168718bcea485efef09f21080 --- /dev/null +++ b/config/tam.yaml @@ -0,0 +1,70 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + # - headers/optimizer.yaml + - backbones/resnet12.yaml + +data_root: /data/fanzhichen/continual/cifar100 +image_size: 32 + + +save_path: ./ +# data +init_cls_num: 20 +inc_cls_num: 20 +task_num: 5 + + +epoch: 0 # 160 +device_ids: 4 +n_gpu: 1 +val_per_epoch: 1 + + +batch_size: 128 + + +optimizer: + name: SGD + kwargs: + lr: 0.03 + momentum: 0.9 + weight_decay: 0.0005 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [60, 120, 170] + +backbone: + name: resnet18 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + + +buffer: + name: LinearBuffer + kwargs: + buffer_size: 1000 + batch_size: 128 + strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: TAM + kwargs: + num_class: 100 + feat_dim: 512 + init_cls_num: 20 + inc_cls_num: 20 + lamda: 1000 + reg_weight: 0.1 + ema_update_freq: 0.05 + ema_alpha: 0.999 + pairwise_weight: 0.1 + alpha: 0.2 + beta: 0.5 + code_dims: 64 \ No newline at end of file diff --git a/config/zz_BIC/bic-resnet32-5dataset-b10-10-5.yaml b/config/zz_BIC/bic-resnet32-5dataset-b10-10-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14ea5bbbf8ac5872144324d04105b9ce61233e2f --- /dev/null +++ b/config/zz_BIC/bic-resnet32-5dataset-b10-10-5.yaml @@ -0,0 +1,65 @@ +dataset: &dataset 5-datasets +data_root: /data/Dataset/5-dataset +class_order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] + +total_cls_num: &total_cls_num 50 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 5 +image_size: &image_size 32 + +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num + +epoch: 250 +stage2_epoch: 250 +val_per_epoch: 50 +batch_size: 128 + +testing_times: 1 + +seed: 0 + +num_workers: 0 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 2e-4 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [100, 150, 200] + +# done +backbone: + name: cifar_resnet32_V2 # cifar_resnet32_V2 , resnet32 for dataset cifar100, see original paper + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +buffer: + name: LinearSpiltBuffer + kwargs: + buffer_size: 1000 + batch_size: 128 + strategy: balance_random # random, equal_random, reservoir, herding + val_ratio: 0.1 + +classifier: + name: bic + kwargs: + num_class: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num \ No newline at end of file diff --git a/config/zz_BIC/bic-resnet32-cifar100-b10-10-10.yaml b/config/zz_BIC/bic-resnet32-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..21f727c0cd2d5d6fd42ba0603cb0253d1e01d6f2 --- /dev/null +++ b/config/zz_BIC/bic-resnet32-cifar100-b10-10-10.yaml @@ -0,0 +1,77 @@ + +dataset: &dataset binary_cifar100 +data_root: /home/lvqiexuan/temp_data/binary_cifar100 + +total_cls_num: &total_cls_num 100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 10 +image_size: &image_size 32 + +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num + +epoch: 1 # normally 250, 1 for online setting +stage2_epoch: 1 # normally 250, 1 for online setting +val_per_epoch: 50 +batch_size: 10 # normally 128, 10 for online setting + +testing_times: 1 + +seed: 1993 + +num_workers: 0 + +train_trfms: + - RandomHorizontalFlip: {} + - RandomCrop: + size : *image_size + padding : 4 + - ToTensor: {} + - Normalize: + mean: [0.5071, 0.4866, 0.4409] + std: [0.2673, 0.2564, 0.2762] + +test_trfms: + - ToTensor: {} + - Normalize: + mean: [0.5071, 0.4866, 0.4409] + std: [0.2673, 0.2564, 0.2762] + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 2e-4 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [100, 150, 200] + +# done +backbone: + name: cifar_resnet32_V2 # cifar_resnet32_V2 , resnet32 for dataset cifar100, see original paper + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +buffer: + name: LinearSpiltBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 + strategy: balance_random # random, equal_random, reservoir, herding + val_ratio: 0.1 + +classifier: + name: bic + kwargs: + num_class: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num \ No newline at end of file diff --git a/config/zz_BIC/bic-resnet32-cifar100-b20-20-5.yaml b/config/zz_BIC/bic-resnet32-cifar100-b20-20-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..742e13167df994192c54feea45db1d271927f980 --- /dev/null +++ b/config/zz_BIC/bic-resnet32-cifar100-b20-20-5.yaml @@ -0,0 +1,77 @@ + +dataset: &dataset binary_cifar100 +data_root: /home/lvqiexuan/temp_data/binary_cifar100 + +total_cls_num: &total_cls_num 100 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 5 +image_size: &image_size 32 + +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num + +epoch: 250 +stage2_epoch: 250 +val_per_epoch: 50 +batch_size: 128 + +testing_times: 5 + +seed: 2 + +num_workers: 0 + +train_trfms: + - RandomHorizontalFlip: {} + - RandomCrop: + size : *image_size + padding : 4 + - ToTensor: {} + - Normalize: + mean: [0.5071, 0.4866, 0.4409] + std: [0.2673, 0.2564, 0.2762] + +test_trfms: + - ToTensor: {} + - Normalize: + mean: [0.5071, 0.4866, 0.4409] + std: [0.2673, 0.2564, 0.2762] + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 2e-4 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [100, 150, 200] + +# done +backbone: + name: cifar_resnet32_V2 # cifar_resnet32_V2 , resnet32 for dataset cifar100, see original paper + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +buffer: + name: LinearSpiltBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 + strategy: balance_random # random, equal_random, reservoir, herding + val_ratio: 0.1 + +classifier: + name: bic + kwargs: + num_class: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num \ No newline at end of file diff --git a/config/zz_BIC/bic-resnet32-imagenetr-b20-20-10.yaml b/config/zz_BIC/bic-resnet32-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a5cb4117cbd07b90093d87d9c9412457eccc00bb --- /dev/null +++ b/config/zz_BIC/bic-resnet32-imagenetr-b20-20-10.yaml @@ -0,0 +1,59 @@ +dataset: &dataset imagenet-r +data_root: /data/Dataset/imagenet-r + +total_cls_num: &total_cls_num 200 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 + +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num + +epoch: 1 #250 +stage2_epoch: 1 #250 +val_per_epoch: 50 +batch_size: 10 #128 + +testing_times: 1 + +seed: 0 + +num_workers: 0 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 2e-4 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [100, 150, 200] + +# done +backbone: + name: cifar_resnet32_V2 # cifar_resnet32_V2 , resnet32 for dataset cifar100, see original paper + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +buffer: + name: LinearSpiltBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 + strategy: balance_random # random, equal_random, reservoir, herding + val_ratio: 0.1 + +classifier: + name: bic + kwargs: + num_class: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num \ No newline at end of file diff --git a/config/zz_BIC/bic-resnet32-tiny-b20-20-10.yaml b/config/zz_BIC/bic-resnet32-tiny-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bd34032c59e8db96c35c189c8e37f584634624fd --- /dev/null +++ b/config/zz_BIC/bic-resnet32-tiny-b20-20-10.yaml @@ -0,0 +1,59 @@ +dataset: &dataset tiny-imagenet +data_root: /data/Dataset/ + +total_cls_num: &total_cls_num 200 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 + +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num + +epoch: 1 #250 +stage2_epoch: 1 #250 +val_per_epoch: 50 +batch_size: 10 #128 + +testing_times: 1 + +seed: 0 + +num_workers: 0 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 2e-4 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [100, 150, 200] + +# done +backbone: + name: cifar_resnet32_V2 # cifar_resnet32_V2 , resnet32 for dataset cifar100, see original paper + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +buffer: + name: LinearSpiltBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 + strategy: balance_random # random, equal_random, reservoir, herding + val_ratio: 0.1 + +classifier: + name: bic + kwargs: + num_class: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num \ No newline at end of file diff --git a/config/zz_CL-LoRA/cl_lora-cifar100-b5-5-20.yaml b/config/zz_CL-LoRA/cl_lora-cifar100-b5-5-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6a2970e8327d980e69d995738f8ed905dcdff69 --- /dev/null +++ b/config/zz_CL-LoRA/cl_lora-cifar100-b5-5-20.yaml @@ -0,0 +1,75 @@ +# cl_lora-cifar100-b5-5-20 + +dataset: &dataset cifar100 +data_root: /data/lqx/cifar100 +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 100 +task_num: &task_num 20 +image_size: &image_size 224 +epoch: &epoch 30 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 64 + +setting: task-agnostic + +testing_times: 1 +testing_per_task: False + +seed: 1993 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.3333] + - RandomHorizontalFlip: {} + - ToTensor: {} + +test_trfms: + - Resize: + size: 256 + interpolation: 3 # BICUBIC + - CenterCrop: + size: [*image_size, *image_size] + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + lr: 0.03 + momentum: 0.9 + weight_decay: 0.0001 + +lr_scheduler: + name: CosineAnnealingLR + kwargs: + T_max: *epoch + eta_min: 0 + +backbone: + name: vit_cl_lora + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_CL_LoRA + transformer_layer: Transformer_CL_LoRA + lora_rank: 8 + norm_layer_eps: 1e-6 + +classifier: + name: CL_LoRA + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 \ No newline at end of file diff --git a/config/zz_CL-LoRA/cl_lora-imagenetr-b5-5-40.yaml b/config/zz_CL-LoRA/cl_lora-imagenetr-b5-5-40.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5f7bd30382010403db02b94504fede46edbde1d --- /dev/null +++ b/config/zz_CL-LoRA/cl_lora-imagenetr-b5-5-40.yaml @@ -0,0 +1,76 @@ + + +dataset: &dataset imagenet-r +data_root: /data/lqx/imagenet-r + +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 200 +task_num: &task_num 40 +image_size: &image_size 224 +epoch: &epoch 20 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 32 +seed: 1993 + +setting: task-agnostic + +testing_times: 1 +testing_per_task: False + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.3333] + - RandomHorizontalFlip: {} + - ToTensor: {} + +test_trfms: + - Resize: + size: 256 + interpolation: 3 # BICUBIC + - CenterCrop: + size: [*image_size, *image_size] + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + lr: 0.05 + momentum: 0.9 + weight_decay: 0.0005 + +lr_scheduler: + name: CosineAnnealingLR + kwargs: + T_max: *epoch + eta_min: 0 + +backbone: + name: vit_cl_lora + kwargs: + pretrained: True + model_name : vit_base_patch16_224 + attn_layer: MultiHeadAttention_CL_LoRA + transformer_layer: Transformer_CL_LoRA + lora_rank: 10 + norm_layer_eps: 1e-6 + +classifier: + name: CL_LoRA + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + \ No newline at end of file diff --git a/config/zz_ERACE/erace-resnet18-5dataset-b10-10-5.yaml b/config/zz_ERACE/erace-resnet18-5dataset-b10-10-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fda47fe006cfe7a08984b587f1663359e7e72431 --- /dev/null +++ b/config/zz_ERACE/erace-resnet18-5dataset-b10-10-5.yaml @@ -0,0 +1,46 @@ + + +dataset: &dataset 5-datasets +data_root: /data/Dataset/5-dataset +class_order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] + +total_cls_num: &total_cls_num 50 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 5 +image_size: &image_size 32 + +epoch: 1 +seed: 2 + +batch_size: 10 +testing_times: 1 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + +buffer: + name: ERBuffer + kwargs: + capacity: 10000 # num_classes * M + +backbone: + name: resnet18_AML + kwargs: + dataset: *dataset + num_classes: *total_cls_num + +classifier: + name: ERACE + kwargs: + num_classes: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_free: True + use_augs: False diff --git a/config/zz_ERACE/erace-resnet18-imagenetr-b20-20-10.yaml b/config/zz_ERACE/erace-resnet18-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c2577e7c7cd7aaa47e6e406a37cbce07e64f9b3 --- /dev/null +++ b/config/zz_ERACE/erace-resnet18-imagenetr-b20-20-10.yaml @@ -0,0 +1,41 @@ + +dataset: &dataset imagenet-r +data_root: /data/Dataset/imagenet-r + +total_cls_num: &total_cls_num 200 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +image_size: &image_size 224 + +epoch: 1 +seed: 2 + +batch_size: 10 +testing_times: 1 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + +buffer: + name: ERBuffer + kwargs: + capacity: 20000 # num_classes * M + +backbone: + name: resnet18_AML + kwargs: + dataset: *dataset + num_classes: *total_cls_num + input_size: [3, *image_size, *image_size] + +classifier: + name: ERACE + kwargs: + num_classes: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_free: True + use_augs: False diff --git a/config/zz_ERACE/erace.yaml b/config/zz_ERACE/erace.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a61fb347831b397e88c5d473780ce67ddb3ef60b --- /dev/null +++ b/config/zz_ERACE/erace.yaml @@ -0,0 +1,58 @@ + +dataset: &dataset binary_cifar100 +data_root: /home/lvqiexuan/temp_data/binary_cifar100 + +total_cls_num: &total_cls_num 100 +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +task_num: &task_num 20 +image_size: &image_size 32 + +epoch: 1 +seed: 2 + +batch_size: 10 +testing_times: 1 + +train_trfms: + - ToTensor: {} + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + - RandomCrop: + size: *image_size + padding: 4 + padding_mode: constant + fill: -1 + - RandomHorizontalFlip: {} + +test_trfms: + - ToTensor: {} + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + +optimizer: + name: SGD + kwargs: + lr: 0.1 + +buffer: + name: ERBuffer + kwargs: + capacity: 10000 # num_classes * M + +backbone: + name: resnet18_AML + kwargs: + dataset: *dataset + num_classes: *total_cls_num + +classifier: + name: ERACE + kwargs: + num_classes: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_free: True + use_augs: False diff --git a/config/zz_ERAML/eraml-resnet18-5dataset-b10-10-5.yaml b/config/zz_ERAML/eraml-resnet18-5dataset-b10-10-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26637d89e3254eb4f2a250efb8f6e082bcf8d030 --- /dev/null +++ b/config/zz_ERAML/eraml-resnet18-5dataset-b10-10-5.yaml @@ -0,0 +1,46 @@ +dataset: &dataset 5-datasets +data_root: /data/Dataset/5-dataset +class_order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] + + +total_cls_num: &total_cls_num 50 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 5 +image_size: &image_size 32 + +epoch: 1 + +batch_size: 10 +testing_times: 1 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + +buffer: + name: ERBuffer + kwargs: + capacity: 10000 # num_classes * M + +backbone: + name: resnet18_AML + kwargs: + dataset: *dataset + num_classes: *total_cls_num + +classifier: + name: ERAML + kwargs: + num_classes: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_free: True + use_augs: False + supcon_temperature: 0.2 + use_minimal_selection: False diff --git a/config/zz_ERAML/eraml-resnet18-imagenetr-b20-20-10.yaml b/config/zz_ERAML/eraml-resnet18-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..645b9c1ce4bab7b6cf0886314c9236609ddc732a --- /dev/null +++ b/config/zz_ERAML/eraml-resnet18-imagenetr-b20-20-10.yaml @@ -0,0 +1,41 @@ +dataset: &dataset imagenet-r +data_root: /data/Dataset/imagenet-r + +total_cls_num: &total_cls_num 200 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +image_size: &image_size 224 + +epoch: 1 + +batch_size: 10 +testing_times: 1 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + +buffer: + name: ERBuffer + kwargs: + capacity: 20000 # num_classes * M + +backbone: + name: resnet18_AML + kwargs: + dataset: *dataset + num_classes: *total_cls_num + input_size: [3, *image_size, *image_size] + +classifier: + name: ERAML + kwargs: + num_classes: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_free: True + use_augs: False + supcon_temperature: 0.2 + use_minimal_selection: False diff --git a/config/zz_ERAML/eraml.yaml b/config/zz_ERAML/eraml.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6b024d83fb1d8e38ebc0a179c3326595a746c91 --- /dev/null +++ b/config/zz_ERAML/eraml.yaml @@ -0,0 +1,58 @@ +dataset: &dataset binary_cifar100 +total_cls_num: &total_cls_num 100 +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +task_num: &task_num 20 +image_size: &image_size 32 + +data_root: /home/lvqiexuan/temp_data/binary_cifar100 + +epoch: 1 + +batch_size: 10 +testing_times: 1 + +train_trfms: + - ToTensor: {} + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + - RandomCrop: + size: *image_size + padding: 4 + padding_mode: constant + fill: -1 + - RandomHorizontalFlip: {} + +test_trfms: + - ToTensor: {} + - Normalize: + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + +optimizer: + name: SGD + kwargs: + lr: 0.1 + +buffer: + name: ERBuffer + kwargs: + capacity: 10000 # num_classes * M + +backbone: + name: resnet18_AML + kwargs: + dataset: *dataset + num_classes: *total_cls_num + +classifier: + name: ERAML + kwargs: + num_classes: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_free: True + use_augs: False + supcon_temperature: 0.2 + use_minimal_selection: False diff --git a/config/zz_GPM/gpm_cil-alexnet-cifar100-b10-10-10.yaml b/config/zz_GPM/gpm_cil-alexnet-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e421fc8c50fb6350c2e6aff588ba206a963bc857 --- /dev/null +++ b/config/zz_GPM/gpm_cil-alexnet-cifar100-b10-10-10.yaml @@ -0,0 +1,38 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_cil-alexnet-cifar100-b2-2-50.yaml b/config/zz_GPM/gpm_cil-alexnet-cifar100-b2-2-50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cc94970c6f5466e4545bf67351c740f37930fe6e --- /dev/null +++ b/config/zz_GPM/gpm_cil-alexnet-cifar100-b2-2-50.yaml @@ -0,0 +1,39 @@ +init_cls_num: &init_cls_num 2 +inc_cls_num: &inc_cls_num 2 +total_cls_num: &total_cls_num 100 +task_num: &task_num 50 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 # 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic +seed: 1993 +testing_times: 1 # Don't set too high, it will take eternity + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_cil-alexnet-cifar100-b20-20-5.yaml b/config/zz_GPM/gpm_cil-alexnet-cifar100-b20-20-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f4236030cdcb13b386139ee7012d3418a75055c --- /dev/null +++ b/config/zz_GPM/gpm_cil-alexnet-cifar100-b20-20-5.yaml @@ -0,0 +1,38 @@ +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_cil-alexnet-cifar100-b5-5-20.yaml b/config/zz_GPM/gpm_cil-alexnet-cifar100-b5-5-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d2ce08b19eeb00aeb35d4a3873e0df4b18c1fff --- /dev/null +++ b/config/zz_GPM/gpm_cil-alexnet-cifar100-b5-5-20.yaml @@ -0,0 +1,38 @@ +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 100 +task_num: &task_num 20 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_cil-alexnet-tiny-b100-10-11.yaml b/config/zz_GPM/gpm_cil-alexnet-tiny-b100-10-11.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72f480ac629293bffb04ea40bbddadfd4b4eecda --- /dev/null +++ b/config/zz_GPM/gpm_cil-alexnet-tiny-b100-10-11.yaml @@ -0,0 +1,41 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 200 +task_num: &task_num 11 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_cil-alexnet-tiny-b100-20-6.yaml b/config/zz_GPM/gpm_cil-alexnet-tiny-b100-20-6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32bb0bcd90b3e2b666b5291b96c5bc56a9edad07 --- /dev/null +++ b/config/zz_GPM/gpm_cil-alexnet-tiny-b100-20-6.yaml @@ -0,0 +1,41 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 200 +task_num: &task_num 6 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_cil-alexnet-tiny-b100-5-21.yaml b/config/zz_GPM/gpm_cil-alexnet-tiny-b100-5-21.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6cbb13cb7699b3835e35447f68a84f4e0836261 --- /dev/null +++ b/config/zz_GPM/gpm_cil-alexnet-tiny-b100-5-21.yaml @@ -0,0 +1,41 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 200 +task_num: &task_num 21 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_til-alexnet-cifar100-b10-10-10.yaml b/config/zz_GPM/gpm_til-alexnet-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa854aa099c578e1a75f050deeba845f37e77567 --- /dev/null +++ b/config/zz_GPM/gpm_til-alexnet-cifar100-b10-10-10.yaml @@ -0,0 +1,39 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +init_epoch: 200 +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-aware # [task-aware, task-agnostic] +seed: 2 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_til-alexnet-cifar100-b20-20-5.yaml b/config/zz_GPM/gpm_til-alexnet-cifar100-b20-20-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4e22deb52e9175a798382686fcf5a06565ccb081 --- /dev/null +++ b/config/zz_GPM/gpm_til-alexnet-cifar100-b20-20-5.yaml @@ -0,0 +1,39 @@ +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +init_epoch: 200 # 200 +epoch: 200 # 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-aware # [task-aware, task-agnostic] +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_til-alexnet-tiny-b100-20-6.yaml b/config/zz_GPM/gpm_til-alexnet-tiny-b100-20-6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c2e87f9cc9924229dfa1a662a0ad199dddd09f5 --- /dev/null +++ b/config/zz_GPM/gpm_til-alexnet-tiny-b100-20-6.yaml @@ -0,0 +1,41 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 200 +task_num: &task_num 6 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-aware +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_GPM/gpm_til-alexnet-tiny-b100-5-21.yaml b/config/zz_GPM/gpm_til-alexnet-tiny-b100-5-21.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38983bdb2df85a3d4b66261ca4d637ad7046fe5d --- /dev/null +++ b/config/zz_GPM/gpm_til-alexnet-tiny-b100-5-21.yaml @@ -0,0 +1,41 @@ +init_cls_num: &init_cls_num 40 +inc_cls_num: &inc_cls_num 40 +total_cls_num: &total_cls_num 200 +task_num: &task_num 5 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-aware +seed: 1993 + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: GPM + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num diff --git a/config/zz_LUCIR/lucir-resnet32-imagenetr-b20-20-10.yaml b/config/zz_LUCIR/lucir-resnet32-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae0ff9a1d84db7ed54fbb52af20a8da0b89920a2 --- /dev/null +++ b/config/zz_LUCIR/lucir-resnet32-imagenetr-b20-20-10.yaml @@ -0,0 +1,60 @@ + +dataset: &dataset imagenet-r +data_root: /data/Dataset/imagenet-r + +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 200 +task_num: &task_num 10 +image_size: &image_size 224 + +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num + +epoch: 1 +val_per_epoch: 160 +batch_size: 10 + +testing_times: 1 + +seed: 2 + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [80, 120] + +backbone: + name: resnet32_V2 + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +buffer: + name: LinearBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 + strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: LUCIR + kwargs: + num_class: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + feat_dim: 153664 + dist: 0.5 + lamda: 5 + K: 2 + lw_mr: 1 diff --git a/config/zz_LUCIR/lucir.yaml b/config/zz_LUCIR/lucir.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f53b3ac33a19a71fdf9dd02db4f5b062d61ac9cb --- /dev/null +++ b/config/zz_LUCIR/lucir.yaml @@ -0,0 +1,75 @@ + +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 50 +inc_cls_num: &inc_cls_num 5 # 10 5 +total_cls_num: &total_cls_num 100 +task_num: &task_num 11 # 6 11 +image_size: &image_size 32 + +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num + +epoch: 1 +val_per_epoch: 160 +train_batch_size: 128 +test_batch_size: 100 + +testing_times: 1 + +seed: 2 + +train_trfms: + - RandomCrop : + size : *image_size + padding : 4 + - RandomHorizontalFlip : {} + - ToTensor: {} + - Normalize: + mean: [0.5071, 0.4866, 0.4409] + std: [0.2009, 0.1984, 0.2023] + +test_trfms: + - ToTensor: {} + - Normalize: + mean: [0.5071, 0.4866, 0.4409] + std: [0.2009, 0.1984, 0.2023] + +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [80, 120] + +backbone: + name: resnet32_V2 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + +buffer: + name: LinearBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 + strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: LUCIR + kwargs: + num_class: *total_cls_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + feat_dim: 64 + dist: 0.5 + lamda: 5 + K: 2 + lw_mr: 1 diff --git a/config/zz_LoRA-Sub-DRS/lora_sub-cifar100-b10-10-10.yaml b/config/zz_LoRA-Sub-DRS/lora_sub-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fafde09acf594dfb73f3c8ee355dc9f2b0769f82 --- /dev/null +++ b/config/zz_LoRA-Sub-DRS/lora_sub-cifar100-b10-10-10.yaml @@ -0,0 +1,73 @@ +# lora_sub-cifar100-b10-10-10 + +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: 20 +val_per_epoch: 20 + +batch_size: 128 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.08, 1.0] + ratio: [0.75, 1.3333] + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_LoRA_Sub + lora_rank: 10 + +classifier: + name: LoRAsub_DRS + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + fc_lrate: 0.002 + margin_inter: 1.0 + lambada: 0.05 \ No newline at end of file diff --git a/config/zz_LoRA-Sub-DRS/lora_sub-cifar100-b5-5-20.yaml b/config/zz_LoRA-Sub-DRS/lora_sub-cifar100-b5-5-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35b1428e3ea11bae7af76a919c6c071291c9aad7 --- /dev/null +++ b/config/zz_LoRA-Sub-DRS/lora_sub-cifar100-b5-5-20.yaml @@ -0,0 +1,73 @@ +# lora_sub-cifar100-b5-5-20 + +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 100 +task_num: &task_num 20 +image_size: &image_size 224 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: 20 +val_per_epoch: 20 + +batch_size: 128 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.08, 1.0] + ratio: [0.75, 1.3333] + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_LoRA_Sub + lora_rank: 10 + +classifier: + name: LoRAsub_DRS + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + fc_lrate: 0.002 + margin_inter: 1.0 + lambada: 0.05 \ No newline at end of file diff --git a/config/zz_LoRA-Sub-DRS/lora_sub-imgnr-b10-10-20.yaml b/config/zz_LoRA-Sub-DRS/lora_sub-imgnr-b10-10-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67194240ce976b47b6f3ed5b6757252cdd346877 --- /dev/null +++ b/config/zz_LoRA-Sub-DRS/lora_sub-imgnr-b10-10-20.yaml @@ -0,0 +1,78 @@ +# lora_sub-imgnr-b10-10-20 + +dataset: &dataset imagenet-r +data_root: /data/Dataset/imagenet-r +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 200 +task_num: &task_num 20 +image_size: &image_size 224 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +init_epoch: 40 +epoch: 50 +val_per_epoch: 50 + +batch_size: 128 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.08, 1.0] + ratio: [0.75, 1.3333] + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: 256 + - CenterCrop: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_LoRA_Sub + lora_rank: 10 + +classifier: + name: LoRAsub_DRS + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + fc_lrate: 0.0005 + margin_inter: 2.0 + lambada: 0.2 + \ No newline at end of file diff --git a/config/zz_LoRA-Sub-DRS/lora_sub-imgnr-b20-20-10.yaml b/config/zz_LoRA-Sub-DRS/lora_sub-imgnr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96a32ccfdc9b6163846b9ec68e4faf84f097f25a --- /dev/null +++ b/config/zz_LoRA-Sub-DRS/lora_sub-imgnr-b20-20-10.yaml @@ -0,0 +1,78 @@ +# lora_sub-imgnr-b20-20-10 + +dataset: &dataset imagenet-r +data_root: /data/Dataset/imagenet-r +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 200 +task_num: &task_num 10 +image_size: &image_size 224 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +init_epoch: 40 +epoch: 50 +val_per_epoch: 50 + +batch_size: 128 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.08, 1.0] + ratio: [0.75, 1.3333] + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: 256 + - CenterCrop: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: 20 + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_LoRA_Sub + lora_rank: 10 + +classifier: + name: LoRAsub_DRS + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + fc_lrate: 0.0005 + margin_inter: 2.0 + lambada: 0.2 + \ No newline at end of file diff --git a/config/zz_MInfLoRA/MInfLoRA-vit-cifar100-b10-10-10.yaml b/config/zz_MInfLoRA/MInfLoRA-vit-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a15067e8167b086459fecdcd6267d6e3f12f6c21 --- /dev/null +++ b/config/zz_MInfLoRA/MInfLoRA-vit-cifar100-b10-10-10.yaml @@ -0,0 +1,73 @@ +# Dataset CiFar100 +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 10 +image_size: &image_size 224 +epoch: &epoch 20 + +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch + +seed: 1993 + +batch_size: 128 +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop: + size: *image_size + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0.0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: *epoch + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_MaskedLoRA1 + block_layer: ResidualAttentionBiBlock # Bi + transformer_layer: Transformer_Proj + lora_rank: 10 + +classifier: + name: MInfLoRA + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + use_ca: False + eval_mat: False \ No newline at end of file diff --git a/config/zz_MInfLoRA/MInfLoRA-vit-imagenetr-b20-20-10.yaml b/config/zz_MInfLoRA/MInfLoRA-vit-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f5ff27e038c27b9236d60a5aeae29a0dd87ceac --- /dev/null +++ b/config/zz_MInfLoRA/MInfLoRA-vit-imagenetr-b20-20-10.yaml @@ -0,0 +1,73 @@ + +dataset: &dataset "imagenet-r" +data_root: "/home/lvqiexuan/temp_data/imagenet-r/" +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +image_size: &image_size 224 +epoch: &epoch 20 + +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch + +workers: 24 + +seed: 1993 + +batch_size: 128 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.333] + - RandomHorizontalFlip: + p: 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: 256 + interpolation: BICUBIC + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + lr: 8e-3 + momentum: 0.9 + +lr_scheduler: + name: Constant + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_MaskedLoRA1 + block_layer: ResidualAttentionBiBlock # Bi + transformer_layer: Transformer_Proj + lora_rank: 10 + +classifier: + name: MInfLoRA + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + use_ca: False diff --git a/config/zz_MInfLoRA/MInfLoRA.yaml b/config/zz_MInfLoRA/MInfLoRA.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3cb23ef95d240e9605a774dab019c3e7a77464aa --- /dev/null +++ b/config/zz_MInfLoRA/MInfLoRA.yaml @@ -0,0 +1,71 @@ +# Dataset CiFar100 +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 10 +epoch: &epoch 20 +image_size: &image_size 224 + +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 128 +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop: + size: *image_size + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0.0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: *epoch + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_MaskedLoRA1 + block_layer: ResidualAttentionBiBlock # Bi + transformer_layer: Transformer_Proj + lora_rank: 10 + +classifier: + name: MInfLoRA + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + use_ca: False + eval_mat: False \ No newline at end of file diff --git a/config/zz_MInfLoRA/MInfLoRA2.yaml b/config/zz_MInfLoRA/MInfLoRA2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a396828abf1e4ab5f6be8fca1cadb75d0faac79e --- /dev/null +++ b/config/zz_MInfLoRA/MInfLoRA2.yaml @@ -0,0 +1,71 @@ +# Dataset CiFar100 +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 10 +epoch: &epoch 20 +image_size: &image_size 224 + +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 128 +testing_times: 1 #10 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop: + size: *image_size + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0.0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: *epoch + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_MultiMaskedLoRA + block_layer: ResidualAttentionBiBlock # Bi + transformer_layer: Transformer_Proj + lora_rank: 10 + +classifier: + name: MInfLoRA2 + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + use_ca: False + eval_mat: True \ No newline at end of file diff --git a/config/zz_MInfLoRA/MInfLoRA2_imgnr.yaml b/config/zz_MInfLoRA/MInfLoRA2_imgnr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36fa6aa2cc8d0f69d85bae5d107cfe8c9d0e82c7 --- /dev/null +++ b/config/zz_MInfLoRA/MInfLoRA2_imgnr.yaml @@ -0,0 +1,57 @@ +# Dataset ImageNet-R +dataset: &dataset "imagenet-r" +data_root: "/home/lvqiexuan/temp_data/imagenet-r/" +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +epoch: &epoch 10 # 50 +image_size: &image_size 224 + +image_size: *image_size +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch +workers: 24 + +batch_size: 128 +testing_times: 1 #10 + +setting: task-agnostic + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0.0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineAnnealingLR + kwargs: + T_max: *epoch + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_MultiMaskedLoRA + block_layer: ResidualAttentionBiBlock # Bi + transformer_layer: Transformer_Proj + lora_rank: 10 + +classifier: + name: MInfLoRA2 + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + use_ca: False + eval_mat: True diff --git a/config/zz_MInfLoRA/MInfLoRA3.yaml b/config/zz_MInfLoRA/MInfLoRA3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..821fd5ba19122a4afb554abfccca779f42141df3 --- /dev/null +++ b/config/zz_MInfLoRA/MInfLoRA3.yaml @@ -0,0 +1,75 @@ +# Dataset CiFar100 +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 10 +epoch: &epoch 20 +image_size: &image_size 224 + +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 128 +testing_times: 1 #10 + +setting: task-agnostic + +# Distributed training +n_gpu: 1 +pin_memory: True + +train_trfms: + - RandomResizedCrop: + size: *image_size + - RandomHorizontalFlip: {} + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +test_trfms: + - Resize: + size: *image_size + - ToTensor: {} + - Normalize: + mean: [0., 0., 0.] + std: [1., 1., 1.] + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0.0 + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule + kwargs: + K: *epoch + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_MultiMaskedLoRA3 + block_layer: ResidualAttentionBiBlock # Bi + transformer_layer: Transformer_Proj + lora_rank: 10 + +classifier: + name: MInfLoRA3 + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + use_ca: False + eval_mat: False \ No newline at end of file diff --git a/config/zz_MInfLoRA/MInfLoRA3_imgnr.yaml b/config/zz_MInfLoRA/MInfLoRA3_imgnr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..297fb751dabed133307672f3fdd8c420ec62cadc --- /dev/null +++ b/config/zz_MInfLoRA/MInfLoRA3_imgnr.yaml @@ -0,0 +1,58 @@ +# Dataset ImageNet-R +dataset: &dataset "imagenet-r" +data_root: "/data/lqx/imagenet-r/" +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +epoch: &epoch 30 +image_size: &image_size 224 + +image_size: *image_size +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch +workers: 24 + +batch_size: 128 +testing_times: 1 + +setting: task-agnostic + +optimizer: + name: Adam + kwargs: + lr: 0.0005 + weight_decay: 0. + betas: [0.9, 0.999] + +lr_scheduler: + name: CosineSchedule # CosineAnnealingLR + kwargs: + # T_max: *epoch + K: *epoch + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_MultiMaskedLoRA3 + block_layer: ResidualAttentionBiBlock # Bi + transformer_layer: Transformer_Proj + lora_rank: 10 + +classifier: + name: MInfLoRA3 + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + use_ca: False + eval_mat: False \ No newline at end of file diff --git a/config/zz_MInfLoRA/MInfLoRA_imgnr.yaml b/config/zz_MInfLoRA/MInfLoRA_imgnr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bd34e6b0ca8676982832bc17bdd2ba008bc275dd --- /dev/null +++ b/config/zz_MInfLoRA/MInfLoRA_imgnr.yaml @@ -0,0 +1,53 @@ +save_path : "" # log file placing + +# Dataset ImageNet-R +dataset: &dataset "imagenet-r" +data_root: "/home/lvqiexuan/temp_data/imagenet-r/" +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +epoch: &epoch 1 # 50 + +dataset: *dataset +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +epoch: *epoch +val_per_epoch: *epoch +workers: 24 + +batch_size: 128 + +setting: task-aware + +optimizer: + name: SGD + kwargs: + lr: 0.001 + momentum: 0.9 + weight_decay: 0.0 + +lr_scheduler: + name: CosineAnnealingLR + kwargs: + T_max: *epoch + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + attn_layer: MultiHeadAttention_MaskedLoRA + lora_rank: 10 + +classifier: + name: MInfLoRA + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + lame: 1.0 + lamb: 0.95 + embd_dim: 768 + use_ca: False diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b10-10-10.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..99d355c22321aab09fcae143a53df80b116e8c99 --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b10-10-10.yaml @@ -0,0 +1,76 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 +val_per_epoch: 1 + +train_batch_size: 128 +test_batch_size: 64 + +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 2 + step: 1 + top_k : 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b2-2-50.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b2-2-50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3407c2630164d87b672ac88d731faeab37debe6c --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b2-2-50.yaml @@ -0,0 +1,76 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 +val_per_epoch: 1 + +train_batch_size: 64 +test_batch_size: 64 + +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 2 + step: 1 + top_k : 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b5-5-20.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b5-5-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2105f1407455d00f0deec3817732889cf5e7e153 --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-cifar100-b5-5-20.yaml @@ -0,0 +1,76 @@ +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 +val_per_epoch: 1 + +train_batch_size: 128 +test_batch_size: 64 + +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 2 + step: 1 + top_k : 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-imagenetr-b10-10-20.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-imagenetr-b10-10-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35dc973888c0b3f2a61a2a813b81a5523a37f05f --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-imagenetr-b10-10-20.yaml @@ -0,0 +1,78 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 200 +task_num: &task_num 20 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 +val_per_epoch: 1 + +train_batch_size: 128 +test_batch_size: 64 + +testing_times: 1 + +dataset: 'imagenet-r' +data_root: /home/lvqiexuan/temp_data/imagenet-r + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 2 + step: 1 + top_k : 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-10-11.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-10-11.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de7c28ab9d6c900d62f719d5356e4730b0418644 --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-10-11.yaml @@ -0,0 +1,93 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 200 +task_num: &task_num 11 +image_size: &image_size 224 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ +class_order: [131, 181, 22, 172, 144, 92, 97, 187, 58, 93, 6, 70, 106, 68, + 153, 168, 179, 199, 29, 46, 9, 142, 134, 88, 193, 110, 26, + 32, 117, 112, 17, 39, 166, 13, 94, 138, 109, 147, 51, 101, + 59, 188, 116, 5, 170, 99, 100, 167, 180, 146, 65, 1, 104, + 43, 38, 184, 123, 171, 137, 162, 71, 44, 95, 174, 12, 7, + 54, 152, 21, 47, 28, 176, 34, 2, 132, 118, 42, 189, 150, + 14, 165, 41, 192, 45, 82, 128, 63, 57, 197, 160, 53, 75, + 108, 135, 121, 159, 183, 67, 169, 50, 87, 69, 89, 196, + 115, 19, 148, 96, 86, 11, 8, 60, 33, 173, 78, 4, 119, 105, + 182, 127, 177, 30, 186, 40, 49, 178, 76, 157, 161, 73, 164, + 151, 31, 74, 191, 27, 125, 198, 81, 20, 155, 114, 139, 36, + 61, 56, 145, 48, 16, 83, 62, 85, 126, 0, 102, 23, 3, 140, + 15, 195, 133, 113, 190, 141, 52, 163, 156, 80, 111, 90, 175, + 143, 120, 84, 18, 25, 79, 37, 154, 136, 64, 158, 24, 185, + 72, 35, 129, 55, 149, 91, 122, 77, 103, 124, 130, 66, 10, 107, 194, 98] + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 +val_per_epoch: 1 + +train_batch_size: 64 +test_batch_size: 64 + +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 2 + step: 1 + top_k : 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-20-6.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-20-6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c4362a74fe111ea3d8409aea26e145eef6f0d54 --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-20-6.yaml @@ -0,0 +1,93 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 200 +task_num: &task_num 6 +image_size: &image_size 224 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ +class_order: [131, 181, 22, 172, 144, 92, 97, 187, 58, 93, 6, 70, 106, 68, + 153, 168, 179, 199, 29, 46, 9, 142, 134, 88, 193, 110, 26, + 32, 117, 112, 17, 39, 166, 13, 94, 138, 109, 147, 51, 101, + 59, 188, 116, 5, 170, 99, 100, 167, 180, 146, 65, 1, 104, + 43, 38, 184, 123, 171, 137, 162, 71, 44, 95, 174, 12, 7, + 54, 152, 21, 47, 28, 176, 34, 2, 132, 118, 42, 189, 150, + 14, 165, 41, 192, 45, 82, 128, 63, 57, 197, 160, 53, 75, + 108, 135, 121, 159, 183, 67, 169, 50, 87, 69, 89, 196, + 115, 19, 148, 96, 86, 11, 8, 60, 33, 173, 78, 4, 119, 105, + 182, 127, 177, 30, 186, 40, 49, 178, 76, 157, 161, 73, 164, + 151, 31, 74, 191, 27, 125, 198, 81, 20, 155, 114, 139, 36, + 61, 56, 145, 48, 16, 83, 62, 85, 126, 0, 102, 23, 3, 140, + 15, 195, 133, 113, 190, 141, 52, 163, 156, 80, 111, 90, 175, + 143, 120, 84, 18, 25, 79, 37, 154, 136, 64, 158, 24, 185, + 72, 35, 129, 55, 149, 91, 122, 77, 103, 124, 130, 66, 10, 107, 194, 98] + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 +val_per_epoch: 1 + +train_batch_size: 64 +test_batch_size: 64 + +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 2 + step: 1 + top_k : 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-5-21.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-5-21.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f3432f79c125d22520362158e94f7710ffaee54 --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl-clip-tiny-b100-5-21.yaml @@ -0,0 +1,93 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 200 +task_num: &task_num 21 +image_size: &image_size 224 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ +class_order: [131, 181, 22, 172, 144, 92, 97, 187, 58, 93, 6, 70, 106, 68, + 153, 168, 179, 199, 29, 46, 9, 142, 134, 88, 193, 110, 26, + 32, 117, 112, 17, 39, 166, 13, 94, 138, 109, 147, 51, 101, + 59, 188, 116, 5, 170, 99, 100, 167, 180, 146, 65, 1, 104, + 43, 38, 184, 123, 171, 137, 162, 71, 44, 95, 174, 12, 7, + 54, 152, 21, 47, 28, 176, 34, 2, 132, 118, 42, 189, 150, + 14, 165, 41, 192, 45, 82, 128, 63, 57, 197, 160, 53, 75, + 108, 135, 121, 159, 183, 67, 169, 50, 87, 69, 89, 196, + 115, 19, 148, 96, 86, 11, 8, 60, 33, 173, 78, 4, 119, 105, + 182, 127, 177, 30, 186, 40, 49, 178, 76, 157, 161, 73, 164, + 151, 31, 74, 191, 27, 125, 198, 81, 20, 155, 114, 139, 36, + 61, 56, 145, 48, 16, 83, 62, 85, 126, 0, 102, 23, 3, 140, + 15, 195, 133, 113, 190, 141, 52, 163, 156, 80, 111, 90, 175, + 143, 120, 84, 18, 25, 79, 37, 154, 136, 64, 158, 24, 185, + 72, 35, 129, 55, 149, 91, 122, 77, 103, 124, 130, 66, 10, 107, 194, 98] + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 +val_per_epoch: 1 + +train_batch_size: 64 +test_batch_size: 64 + +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - _convert_to_rgb: {} + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 2 + step: 1 + top_k : 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c061ab5b9dfadfb9b9488f8317a3d074f0a15e9 --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl.yaml @@ -0,0 +1,73 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 +val_per_epoch: 1 + +train_batch_size: 128 +test_batch_size: 64 + +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MoE_MLP + experts_num: 2 + step: 1 + top_k : 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. diff --git a/config/zz_MoeAdapter4CL/moe_adapter4cl_vit.yaml b/config/zz_MoeAdapter4CL/moe_adapter4cl_vit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..591eddf754f17cc5f779e87f81d33ab2be258757 --- /dev/null +++ b/config/zz_MoeAdapter4CL/moe_adapter4cl_vit.yaml @@ -0,0 +1,70 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 1 # 1 +val_per_epoch: 1 # 1 + +train_batch_size: 128 +test_batch_size: 64 + +testing_times: 1 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224_in21k + experts_num: 2 + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: MOE_ADAPTER4CL + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + prompt_template : "a bad photo of a {}." + label_smoothing: 0. + diff --git a/config/zz_OCM/ocm-resnet18-5dataset-b10-10-5.yaml b/config/zz_OCM/ocm-resnet18-5dataset-b10-10-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..39fa3f0c2d1b0225a8d5d30620d708e77476f01b --- /dev/null +++ b/config/zz_OCM/ocm-resnet18-5dataset-b10-10-5.yaml @@ -0,0 +1,58 @@ +dataset: &dataset 5-datasets +data_root: /data/Dataset/5-dataset +class_order: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] + +total_cls_num: &total_cls_num 50 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 5 +image_size: &image_size 32 + +image_size: *image_size +warmup: 0 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num + +batch_size: 10 +epoch: 1 +val_per_epoch: 1 + +testing_times: 1 +seed: 1993 + +optimizer: + name: Adam + kwargs: + lr: 0.001 + weight_decay: 0.0001 + +backbone: + name: resnet18 + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +buffer: + name: OnlineBuffer + kwargs: + buffer_size: 5000 + batch_size: 64 + input_size: [3, *image_size, *image_size] + # strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: OCM + kwargs: + num_class: *total_cls_num + task_num: *task_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + feat_dim: 512 + image_size: *image_size diff --git a/config/zz_OCM/ocm-resnet18-cifar100-b2-2-50.yaml b/config/zz_OCM/ocm-resnet18-cifar100-b2-2-50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b2752e434aa141e04aea649084d6a7d2b3608dd --- /dev/null +++ b/config/zz_OCM/ocm-resnet18-cifar100-b2-2-50.yaml @@ -0,0 +1,63 @@ +dataset: &dataset cifar100 +total_cls_num: &total_cls_num 100 +init_cls_num: &init_cls_num 10 #2 +inc_cls_num: &inc_cls_num 10 #2 +task_num: &task_num 10 #50 +image_size: &image_size 32 + +image_size: *image_size +warmup: 0 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num + +batch_size: 10 +epoch: 1 +val_per_epoch: 1 + +testing_times: 1 +seed: 1993 + +train_trfms: + - ToTensor: {} + - Normalize: + mean: [0.4913725490196078, 0.4823529411764706, 0.4466666666666667] # [x / 255 for x in [125.3, 123.0, 113.9]] + std: [0.2470588235294118, 0.2435294117647059, 0.2615686274509804] # [x / 255 for x in [63.0, 62.1, 66.7]] + +test_trfms: + - ToTensor: {} + - Normalize: + mean: [0.4913725490196078, 0.4823529411764706, 0.4466666666666667] + std: [0.2470588235294118, 0.2435294117647059, 0.2615686274509804] + +optimizer: + name: Adam + kwargs: + lr: 0.001 + weight_decay: 0.0001 + +backbone: + name: resnet18 + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +buffer: + name: OnlineBuffer + kwargs: + buffer_size: 5000 + batch_size: 64 + input_size: [3, *image_size, *image_size] + # strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: OCM + kwargs: + num_class: *total_cls_num + task_num: *task_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + feat_dim: 512 + image_size: *image_size diff --git a/config/zz_OCM/ocm-resnet18-imagenetr-b20-20-10.yaml b/config/zz_OCM/ocm-resnet18-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ae9391e902ec9de30de04fdd100518042e5b2af --- /dev/null +++ b/config/zz_OCM/ocm-resnet18-imagenetr-b20-20-10.yaml @@ -0,0 +1,55 @@ +dataset: &dataset imagenet-r +data_root: /data/Dataset/imagenet-r + +total_cls_num: &total_cls_num 200 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +image_size: &image_size 224 + +image_size: *image_size +warmup: 0 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num + +batch_size: 10 +epoch: 1 +val_per_epoch: 1 + +testing_times: 1 +seed: 1993 + +optimizer: + name: Adam + kwargs: + lr: 0.001 + weight_decay: 0.0001 + +backbone: + name: resnet18 + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + +buffer: + name: OnlineBuffer + kwargs: + buffer_size: 2000 # M * 200 + batch_size: 64 + input_size: [3, *image_size, *image_size] + # strategy: herding # random, equal_random, reservoir, herding + +classifier: + name: OCM + kwargs: + num_class: *total_cls_num + task_num: *task_num + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + feat_dim: 512 + image_size: *image_size diff --git a/config/zz_SD-LoRA/sd_lora-vit-cifar100-b10-10-10.yaml b/config/zz_SD-LoRA/sd_lora-vit-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b89c46527ef7d1657f84b136efe08e7ad5e9b25 --- /dev/null +++ b/config/zz_SD-LoRA/sd_lora-vit-cifar100-b10-10-10.yaml @@ -0,0 +1,73 @@ +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 +epoch: &epoch 20 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 128 + +seed: 42 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale : [0.05, 1.0] + ratio : [0.75, 1.33333] + - RandomHorizontalFlip: + p : 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: *image_size + interpolation : BICUBIC + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + lr: 8e-3 + momentum: 0.9 + +lr_scheduler: + name: Constant + +# vit_base_patch16_224.augreg2_in21k_ft_in1k +# vit_base_patch16_224 +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224.augreg2_in21k_ft_in1k + attn_layer: MultiHeadAttention_SDLoRA + lora_rank: 10 + +classifier: + name: SD_LoRA + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + init_mag : 1.0 + rank_reduction: [False, 4, 8, 8, 6] + knowledge_dist: [False, 9e-4] + \ No newline at end of file diff --git a/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b10-10-20.yaml b/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b10-10-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8a9c0f1249740fa1714e37d0694b44702a3f737 --- /dev/null +++ b/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b10-10-20.yaml @@ -0,0 +1,78 @@ +dataset: &dataset "imagenet-r" +data_root: "/home/lvqiexuan/temp_data/imagenet-r/" + +total_cls_num: &total_cls_num 200 +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +task_num: &task_num 20 +image_size: &image_size 224 +epoch: &epoch 20 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 128 + +seed: 42 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.333] + - RandomHorizontalFlip: + p: 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: 256 + interpolation: BICUBIC + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + lr: 1e-2 + momentum: 0.9 + +# In source code + +lr_scheduler: + name: Constant + kwargs: + +# vit_base_patch16_224.augreg2_in21k_ft_in1k +# vit_base_patch16_224 +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224.augreg2_in21k_ft_in1k + attn_layer: MultiHeadAttention_SDLoRA + lora_rank: 10 + +classifier: + name: SD_LoRA + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + init_mag: 1.0 + rank_reduction: [False, 4, 8, 8, 6] + knowledge_dist: [True, 9e-4] + \ No newline at end of file diff --git a/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b20-20-10.yaml b/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e8d62be9e8ca515722e7301242712247b535272 --- /dev/null +++ b/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b20-20-10.yaml @@ -0,0 +1,75 @@ +dataset: &dataset "imagenet-r" +data_root: "/home/lvqiexuan/temp_data/imagenet-r/" + +total_cls_num: &total_cls_num 200 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +task_num: &task_num 10 +image_size: &image_size 224 +epoch: &epoch 20 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 128 + +seed: 1993 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.333] + - RandomHorizontalFlip: + p: 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: 256 + interpolation: BICUBIC + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + lr: 8e-3 + momentum: 0.9 + +lr_scheduler: + name: Constant + +# vit_base_patch16_224.augreg2_in21k_ft_in1k +# vit_base_patch16_224 +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224.augreg2_in21k_ft_in1k + attn_layer: MultiHeadAttention_SDLoRA + lora_rank: 10 + +classifier: + name: SD_LoRA + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + init_mag: 1.0 + rank_reduction: [False, 4, 8, 8, 6] + knowledge_dist: [False, 9e-4] + \ No newline at end of file diff --git a/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b40-40-5.yaml b/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b40-40-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e7a6ebfa0a1648de832c29c1a3dd2d20e61f165c --- /dev/null +++ b/config/zz_SD-LoRA/sd_lora-vit-imagenetr-b40-40-5.yaml @@ -0,0 +1,76 @@ +dataset: &dataset "imagenet-r" +data_root: "/home/lvqiexuan/temp_data/imagenet-r/" + +total_cls_num: &total_cls_num 200 +init_cls_num: &init_cls_num 40 +inc_cls_num: &inc_cls_num 40 +task_num: &task_num 5 +image_size: &image_size 224 +epoch: &epoch 20 + +dataset: *dataset +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +task_num: *task_num + +epoch: *epoch +val_per_epoch: *epoch + +batch_size: 128 + +seed: 42 + +setting: task-agnostic + +testing_times: 1 + +train_trfms: + - RandomResizedCrop: + size: *image_size + scale: [0.05, 1.0] + ratio: [0.75, 1.333] + - RandomHorizontalFlip: + p: 0.5 + - ToTensor: {} + +test_trfms: + - Resize: + size: 256 + interpolation: BICUBIC + - CenterCrop: + size: *image_size + - ToTensor: {} + +optimizer: + name: SGD + kwargs: + lr: 1e-2 + momentum: 0.9 + +lr_scheduler: + name: Constant + kwargs: + +# vit_base_patch16_224.augreg2_in21k_ft_in1k +# vit_base_patch16_224 +backbone: + name: vit_pt_imnet + kwargs: + pretrained: True + model_name : vit_base_patch16_224.augreg2_in21k_ft_in1k + attn_layer: MultiHeadAttention_SDLoRA + lora_rank: 10 + +classifier: + name: SD_LoRA + kwargs: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + embd_dim: 768 + init_mag: 1.0 + rank_reduction: [False, 4, 8, 8, 6] + knowledge_dist: [False, 9e-4] + \ No newline at end of file diff --git a/config/zz_TRGP/trgp_cil-alexnet-cifar100-b10-10-10.yaml b/config/zz_TRGP/trgp_cil-alexnet-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ed4332480f0bf3b94497598e3d721fd54a9ef028 --- /dev/null +++ b/config/zz_TRGP/trgp_cil-alexnet-cifar100-b10-10-10.yaml @@ -0,0 +1,38 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. diff --git a/config/zz_TRGP/trgp_cil-alexnet-cifar100-b2-2-50.yaml b/config/zz_TRGP/trgp_cil-alexnet-cifar100-b2-2-50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bfc443511e1f0da252b628310d433ae956fe39f7 --- /dev/null +++ b/config/zz_TRGP/trgp_cil-alexnet-cifar100-b2-2-50.yaml @@ -0,0 +1,40 @@ +init_cls_num: &init_cls_num 2 +inc_cls_num: &inc_cls_num 2 +total_cls_num: &total_cls_num 100 +task_num: &task_num 50 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic +testing_times: 1 # Don't set too high, it will take eternity + + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. diff --git a/config/zz_TRGP/trgp_cil-alexnet-cifar100-b20-20-5.yaml b/config/zz_TRGP/trgp_cil-alexnet-cifar100-b20-20-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2247d5eab0f058a3b1f88a2018935d2b6aa2f41 --- /dev/null +++ b/config/zz_TRGP/trgp_cil-alexnet-cifar100-b20-20-5.yaml @@ -0,0 +1,38 @@ +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-agnostic + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. diff --git a/config/zz_TRGP/trgp_cil-alexnet-imagenetr-b10-10-10.yaml b/config/zz_TRGP/trgp_cil-alexnet-imagenetr-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58e80bf456bea33cfc80fb7cd8d735c7a736bf8c --- /dev/null +++ b/config/zz_TRGP/trgp_cil-alexnet-imagenetr-b10-10-10.yaml @@ -0,0 +1,41 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +dataset: 'imagenet-r' +data_root: /home/lvqiexuan/temp_data/imagenet-r + +setting: task-agnostic + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. diff --git a/config/zz_TRGP/trgp_cil-clip-cifar100-b10-10-10.yaml b/config/zz_TRGP/trgp_cil-clip-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cbce2deefe2760a7d247716873ef8af780d87d1a --- /dev/null +++ b/config/zz_TRGP/trgp_cil-clip-cifar100-b10-10-10.yaml @@ -0,0 +1,66 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 +val_per_epoch: 4 + +train_batch_size: 128 +test_batch_size: 64 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MaskedMLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. + prompt_template : "a bad photo of a {}." \ No newline at end of file diff --git a/config/zz_TRGP/trgp_cil-clip-cifar100-b2-2-50.yaml b/config/zz_TRGP/trgp_cil-clip-cifar100-b2-2-50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..09efa03f49295fcc912086ca219400e75e7aafc9 --- /dev/null +++ b/config/zz_TRGP/trgp_cil-clip-cifar100-b2-2-50.yaml @@ -0,0 +1,67 @@ +init_cls_num: &init_cls_num 2 +inc_cls_num: &inc_cls_num 2 +total_cls_num: &total_cls_num 100 +task_num: &task_num 50 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 +val_per_epoch: 4 + +testing_times: 1 +train_batch_size: 128 +test_batch_size: 64 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MaskedMLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. + prompt_template : "a bad photo of a {}." \ No newline at end of file diff --git a/config/zz_TRGP/trgp_cil-clip-cifar100-b20-20-5.yaml b/config/zz_TRGP/trgp_cil-clip-cifar100-b20-20-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e004bc3bc3c41acca4fd5794376db3892e8620bf --- /dev/null +++ b/config/zz_TRGP/trgp_cil-clip-cifar100-b20-20-5.yaml @@ -0,0 +1,66 @@ +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 +val_per_epoch: 4 + +train_batch_size: 128 +test_batch_size: 64 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MaskedMLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *init_cls_num + task_num: *task_num + label_smoothing: 0. + prompt_template : "a bad photo of a {}." \ No newline at end of file diff --git a/config/zz_TRGP/trgp_cil-clip-cifar100-b5-5-20.yaml b/config/zz_TRGP/trgp_cil-clip-cifar100-b5-5-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f4c8dd5d8a781a044873d6a0577215ecdae41ed --- /dev/null +++ b/config/zz_TRGP/trgp_cil-clip-cifar100-b5-5-20.yaml @@ -0,0 +1,67 @@ +init_cls_num: &init_cls_num 5 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 100 +task_num: &task_num 20 +image_size: &image_size 224 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 # 4 +val_per_epoch: 4 # 4 + +train_batch_size: 128 +test_batch_size: 64 + +setting: task-agnostic +testing_times: 1 + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MaskedMLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. + prompt_template : "a bad photo of a {}." \ No newline at end of file diff --git a/config/zz_TRGP/trgp_cil-clip-tiny-b100-10-11.yaml b/config/zz_TRGP/trgp_cil-clip-tiny-b100-10-11.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d92c38c77727c389444d0810ad05fd9566a0241 --- /dev/null +++ b/config/zz_TRGP/trgp_cil-clip-tiny-b100-10-11.yaml @@ -0,0 +1,71 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 11 +image_size: &image_size 224 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ +device_ids: 0 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 +val_per_epoch: 4 + +testing_times: 1 +train_batch_size: 128 +test_batch_size: 64 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MaskedMLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. + prompt_template : "a bad photo of a {}." \ No newline at end of file diff --git a/config/zz_TRGP/trgp_cil-clip-tiny-b100-20-6.yaml b/config/zz_TRGP/trgp_cil-clip-tiny-b100-20-6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6cda50cea4536e719e4a63b9e10beab2fa2cb30e --- /dev/null +++ b/config/zz_TRGP/trgp_cil-clip-tiny-b100-20-6.yaml @@ -0,0 +1,71 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 6 +image_size: &image_size 224 + + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 +val_per_epoch: 4 + +testing_times: 1 +train_batch_size: 128 +test_batch_size: 64 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MaskedMLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. + prompt_template : "a bad photo of a {}." \ No newline at end of file diff --git a/config/zz_TRGP/trgp_cil-clip-tiny-b100-5-21.yaml b/config/zz_TRGP/trgp_cil-clip-tiny-b100-5-21.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68c972db63ac3f41ad3eede2a48bfe9a4e8223bf --- /dev/null +++ b/config/zz_TRGP/trgp_cil-clip-tiny-b100-5-21.yaml @@ -0,0 +1,71 @@ +init_cls_num: &init_cls_num 100 +inc_cls_num: &inc_cls_num 5 +total_cls_num: &total_cls_num 100 +task_num: &task_num 21 +image_size: &image_size 224 + +dataset: tiny-imagenet +data_root: /home/lvqiexuan/temp_data/ +device_ids: 0 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 4 +val_per_epoch: 4 + +testing_times: 1 +train_batch_size: 128 +test_batch_size: 64 + +setting: task-agnostic + +train_trfms: + - RandomResizedCrop : + size: *image_size + scale: [0.9, 1.0] + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +test_trfms: + - Resize : + size: *image_size + interpolation: BICUBIC + - ToTensor: {} + - Normalize: + mean: [0.48145466, 0.4578275, 0.40821073] + std: [0.26862954, 0.26130258, 0.27577711] + +optimizer: + name: AdamW + kwargs: + lr: 1e-3 + weight_decay: 0. + +lr_scheduler: + name: CosineAnnealingWarmUp + kwargs: + T_max: 0 # Will be replaced in trainter.py with epoch * len(dataloader) + warmup_length: 30 + +backbone: + name: clip + kwargs: + model_name : ViT-B/16 + pretrained : True + block_layer: ResidualAttentionBlock_MaskedMLP + act_layer: QuickGELU + norm_layer: LayerNorm + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. + prompt_template : "a bad photo of a {}." \ No newline at end of file diff --git a/config/zz_TRGP/trgp_til-alexnet-cifar100-b10-10-10.yaml b/config/zz_TRGP/trgp_til-alexnet-cifar100-b10-10-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d29a14d70ac72e4b0f70a641b9cd284c6cdb4617 --- /dev/null +++ b/config/zz_TRGP/trgp_til-alexnet-cifar100-b10-10-10.yaml @@ -0,0 +1,38 @@ +init_cls_num: &init_cls_num 10 +inc_cls_num: &inc_cls_num 10 +total_cls_num: &total_cls_num 100 +task_num: &task_num 10 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-aware + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. diff --git a/config/zz_TRGP/trgp_til-alexnet-cifar100-b20-20-5.yaml b/config/zz_TRGP/trgp_til-alexnet-cifar100-b20-20-5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..21007925fdad899e213497a7e67f53d0b68f4faa --- /dev/null +++ b/config/zz_TRGP/trgp_til-alexnet-cifar100-b20-20-5.yaml @@ -0,0 +1,38 @@ +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 + +task_num: *task_num +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +total_cls_num: *total_cls_num +epoch: 200 +batch_size: 64 +val_per_epoch: 200 + +setting: task-aware + +optimizer: + name: SGD + kwargs: + lr: 0.01 + +lr_scheduler: + name: PatienceSchedule + kwargs: + patience: 6 + factor: 2 + stopping_lr: 1e-5 + +backbone: + name: AlexNet_TRGP + kwargs: + +classifier: + name: TRGP + kwargs: + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + task_num: *task_num + label_smoothing: 0. diff --git a/config/zz_WA/wa-resnet18-imagenetr-b20-20-10.yaml b/config/zz_WA/wa-resnet18-imagenetr-b20-20-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89ba2db7729ac3cda65e316bcafbf63cd0bf8888 --- /dev/null +++ b/config/zz_WA/wa-resnet18-imagenetr-b20-20-10.yaml @@ -0,0 +1,62 @@ + +dataset: &dataset imagenet-r +data_root: /data/Dataset/imagenet-r + +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 200 +task_num: &task_num 10 +image_size: &image_size 224 + +# Image size +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num +image_size: *image_size + +# Training setting +init_epoch: 1 #250 +epoch: 1 #250 +val_per_epoch: 250 +batch_size: 10 + +# Optimizer settings +optimizer: + name: SGD + kwargs: + lr: 0.1 + weight_decay: 2e-4 + momentum: 0.9 + +# Learning rate scheduler settings +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [100, 150, 200] + +# Backbone architecture settings +backbone: + name: resnet18 + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + init_cls_num: *init_cls_num + inc_cls_num: *inc_cls_num + +# Buffer settings +buffer: + name: LinearHerdingBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 +# strategy: herding # random, equal_random, reservoir, herding + +# Classifier settings +classifier: + name: WA + kwargs: + num_class: *total_cls_num + feat_dim: 512 # 64 + init_cls_num: *init_cls_num diff --git a/config/zz_WA/wa.yaml b/config/zz_WA/wa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0575035c3eaf86156400adbd8e117a78a1bc7ea --- /dev/null +++ b/config/zz_WA/wa.yaml @@ -0,0 +1,59 @@ + +dataset: &dataset cifar100 +init_cls_num: &init_cls_num 20 +inc_cls_num: &inc_cls_num 20 +total_cls_num: &total_cls_num 100 +task_num: &task_num 5 +image_size: &image_size 32 + +# Image size +dataset: cifar +init_cls_num: *init_cls_num +inc_cls_num: *inc_cls_num +task_num: *task_num +image_size: *image_size + +# Training setting +init_epoch: 1 #250 +epoch: 1 #250 +val_per_epoch: 250 +batch_size: 128 + +# Optimizer settings +optimizer: + name: SGD + kwargs: + lr: 0.1 + weight_decay: 2e-4 + momentum: 0.9 + +# Learning rate scheduler settings +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [100, 150, 200] + +# Backbone architecture settings +backbone: + name: resnet18 + kwargs: + num_classes: *total_cls_num + args: + dataset: *dataset + +# Buffer settings +buffer: + name: LinearHerdingBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 +# strategy: herding # random, equal_random, reservoir, herding + +# Classifier settings +classifier: + name: WA + kwargs: + num_class: *total_cls_num + feat_dim: 512 # 64 + init_cls_num: *init_cls_num diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f69dc63f498715bf9a2705183cc4d379eef4b3cd --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +from .trainer import Trainer diff --git a/core/config/__init__.py b/core/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27c9ec622ac08b8d9d1e29fcc1cea3e8afa783b0 --- /dev/null +++ b/core/config/__init__.py @@ -0,0 +1 @@ +from .config import * diff --git a/core/config/config.py b/core/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8d90af0383f21081571bea0f9498b8742cf82847 --- /dev/null +++ b/core/config/config.py @@ -0,0 +1,134 @@ +import argparse +import os +import random +import yaml +import re + +def get_cur_path(): + """Get the absolute path of current file. + + Returns: The absolute path of this file (Config.py). + + """ + return os.path.dirname(__file__) + + +DEFAULT_FILE = os.path.join(get_cur_path(), "default.yaml") + +class Config(object): + """ The config parser of `LibContinual` + `Config` is used to parser *.yaml, console params to python dict. The rules for resolving merge conflicts are as follow + + 1. The merging is recursive, if a key is not be specified, the existing value will be used. + 2. The merge priority is: console params > run_*.py > user defined yaml (/LibContinual/config/*.yaml) > default.yaml(/LibContinual/core/config/*.yaml) + """ + + def __init__(self, config_file=None): + """Initializing the parameter dictionary, completes the merging of all parameter. + + Args: + config_file: Configuration file name. (/LibContinual/config/*.yaml) + """ + self.config_file = config_file + self.default_dict = self._load_config_files(DEFAULT_FILE) + self.file_dict = self._load_config_files(config_file) + self.console_dict = self._load_console_dict() + self.config_dict = self._merge_config_dict() + + def get_config_dict(self): + """ Return the merged dict. + + Returns: + dict: A dict of LibContinual setting. + """ + return self.config_dict + + + @staticmethod + def _load_config_files(config_file): + """Parse a YAML file. + + Args: + config_file (str): Path to yaml file. + + Returns: + dict: A dict of LibContinual setting. + """ + config_dict = dict() + loader = yaml.SafeLoader + loader.add_implicit_resolver( + "tag:yaml.org,2002:float", + re.compile( + """^(?: + [-+]?[0-9][0-9_]*\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?[0-9][0-9_]*[eE][-+]?[0-9]+ + |\\.[0-9_]+(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$""", + re.X, + ), + list("-+0123456789."), + ) + + if config_file is not None: + with open(config_file, "r", encoding="utf-8") as fin: + config_dict.update(yaml.load(fin.read(), Loader=loader)) + config_file_dict = config_dict.copy() + for include in config_dict.get("includes", []): + with open(os.path.join("./config/", include), "r", encoding="utf-8") as fin: + config_dict.update(yaml.load(fin.read(), Loader=loader)) + if config_dict.get("includes") is not None: + config_dict.pop("includes") + config_dict.update(config_file_dict) + return config_dict + + @staticmethod + def _load_console_dict(): + """Parsing command line parameters + + Returns: + dict: A dict of LibContinual console setting. + """ + pass + + @staticmethod + def _update(dic1, dic2): + """Merge dictionaries. + + Used to merge two dictionaries (profiles), `dic2` will overwrite the value of the same key in `dic1`. + + + Args: + dic1 (dict): The dict to be overwritten. (low priority) + dic2 (dict): The dict to overwrite. (high priority) + + Returns: + dict: Merged dict. + """ + + if dic1 is None: + dic1 = dict() + + if dic2 is not None: + for k in dic2.keys(): + dic1[k] = dic2[k] + return dic1 + + + def _merge_config_dict(self): + """Merge all dictionaries. Merge rules are as follow + + 1. The merging is recursive, if a key is not be specified, the existing value will be used. + 2. The merge priority is: console params > run_*.py > user defined yaml (/LibContinual/config/*.yaml) > default.yaml(/LibContinual/core/config/*.yaml) + + Returns: + dict: A complete dict of LibContinual setting. + """ + + config_dict = dict() + config_dict = self._update(config_dict, self.default_dict) + config_dict = self._update(config_dict, self.file_dict) + config_dict = self._update(config_dict, self.console_dict) + + return config_dict \ No newline at end of file diff --git a/core/config/default.yaml b/core/config/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..349ce7acc1b60ccdd858be9154ff82318fa71a14 --- /dev/null +++ b/core/config/default.yaml @@ -0,0 +1,6 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/model.yaml + - headers/optimizer.yaml + - headers/test.yaml diff --git a/core/data/__init__.py b/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ad822cc21eeea46765af97c82261d4a37b4601 --- /dev/null +++ b/core/data/__init__.py @@ -0,0 +1 @@ +from .dataloader import * \ No newline at end of file diff --git a/core/data/custom_transforms.py b/core/data/custom_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8b343ca470294a8470cc85c4ed580e9e99c66d0c --- /dev/null +++ b/core/data/custom_transforms.py @@ -0,0 +1,9 @@ +# ----------------- +# Custom Transfrom : Define your own custom transform function, and add to the list + +custom_trfm_names = ['_convert_to_rgb'] + +def _convert_to_rgb(img): + return img.convert('RGB') + +# ----------------- \ No newline at end of file diff --git a/core/data/data.py b/core/data/data.py new file mode 100644 index 0000000000000000000000000000000000000000..253fe2458074bac5e58131dd3ecaf71ac010b3cf --- /dev/null +++ b/core/data/data.py @@ -0,0 +1,340 @@ +import numpy as np +from torchvision import transforms + +class CIFARTransform: + MEAN = [0.5071, 0.4866, 0.4409] + STD = [0.2675, 0.2565, 0.2761] + + common_trfs = [transforms.ToTensor(), + transforms.Normalize(mean=MEAN, std=STD)] + + resnet_train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63 / 255), + *common_trfs + ]) + + resnet_test_transform = transforms.Compose([*common_trfs]) + + # To Reproduce ERAML, ERACE + #resnet_train_transform = transforms.Compose([*common_trfs]) + + # from + dset_mean = (0., 0., 0.) + dset_std = (1., 1., 1.) + vit_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(dset_mean, dset_std)]) + + vit_test_transform = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize(dset_mean, dset_std)]) + + # from trust region gradient projection + mean=[x/255 for x in [125.3,123.0,113.9]] + std=[x/255 for x in [63.0,62.1,66.7]] + + alexnet_train_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean,std)]) + + alexnet_test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean,std)]) + + @staticmethod + def get_transform(model_type, mode): + if model_type == 'resnet': + if mode == 'train': + return CIFARTransform.resnet_train_transform + elif mode == 'test': + return CIFARTransform.resnet_test_transform + elif model_type == 'vit': + if mode == 'train': + return CIFARTransform.vit_train_transform + elif mode == 'test': + return CIFARTransform.vit_test_transform + elif model_type == 'alexnet': + if mode == 'train': + return CIFARTransform.alexnet_train_transform + elif mode == 'test': + return CIFARTransform.alexnet_test_transform + else: + raise ValueError("Unsupported model type") + +class ImageNetTransform: + MEAN=[0.4914, 0.4822, 0.4465] + STD=[0.2023, 0.1994, 0.2010] + + common_trfs = [transforms.ToTensor(), + transforms.Normalize(mean=MEAN, std=STD)] + + resnet_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63 / 255), + *common_trfs + ]) + + resnet_test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + *common_trfs + ]) + + + dset_mean = (0., 0., 0.) + dset_std = (1., 1., 1.) + vit_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(dset_mean, dset_std), + ]) + + vit_test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(dset_mean, dset_std), + ]) + + @staticmethod + def get_transform(model_type, mode): + if model_type == 'resnet': + if mode == 'train': + return ImageNetTransform.resnet_train_transform + elif mode == 'test': + return ImageNetTransform.resnet_test_transform + elif model_type == 'vit': + if mode == 'train': + return ImageNetTransform.vit_train_transform + elif mode == 'test': + return ImageNetTransform.vit_test_transform + else: + raise ValueError("Unsupported model type") + +class ImageNetRTransform: + mean = [0.4914, 0.4822, 0.4465] + std = [0.2023, 0.1994, 0.2010] + + common_trfs = [transforms.ToTensor(), + transforms.Normalize(mean, std)] + + resnet_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63 / 255), + *common_trfs]) + + resnet_test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + *common_trfs]) + + mean = [0., 0., 0.] + std = [1., 1., 1.] + + common_trfs = [transforms.ToTensor(), + transforms.Normalize(mean, std)] + + vit_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + *common_trfs]) + + vit_test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + *common_trfs]) + + + # from trust region gradient projection + mean=[x/255 for x in [125.3,123.0,113.9]] + std=[x/255 for x in [63.0,62.1,66.7]] + + alexnet_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(32), + transforms.ToTensor(), + transforms.Normalize(mean,std)]) + + alexnet_test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(32), + transforms.ToTensor(), + transforms.Normalize(mean,std)]) + + @staticmethod + def get_transform(model_type, mode): + if model_type == 'resnet': + if mode == 'train': + return ImageNetRTransform.resnet_train_transform + elif mode == 'test': + return ImageNetRTransform.resnet_test_transform + elif model_type == 'vit': + if mode == 'train': + return ImageNetRTransform.vit_train_transform + elif mode == 'test': + return ImageNetRTransform.vit_test_transform + elif model_type == 'alexnet': + if mode == 'train': + return ImageNetRTransform.alexnet_train_transform + elif mode == 'test': + return ImageNetRTransform.alexnet_test_transform + else: + raise ValueError("Unsupported model type") + +class TinyImageNetTransform: + # Standard normalization values for Tiny-ImageNet + MEAN = [0.485, 0.456, 0.406] + STD = [0.229, 0.224, 0.225] + + common_trfs = [transforms.ToTensor(), + transforms.Normalize(mean=MEAN, std=STD)] + + # ResNet Transforms + resnet_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(64), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63 / 255), + *common_trfs + ]) + + resnet_test_transform = transforms.Compose([ + transforms.Resize(64), + transforms.CenterCrop(64), + *common_trfs + ]) + + # ViT Transforms (Using dataset mean/std as [0,0,0] and [1,1,1] for compatibility) + dset_mean = (0., 0., 0.) + dset_std = (1., 1., 1.) + + vit_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(dset_mean, dset_std) + ]) + + vit_test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(dset_mean, dset_std) + ]) + + # from trust region gradient projection + mean=[x/255 for x in [125.3,123.0,113.9]] + std=[x/255 for x in [63.0,62.1,66.7]] + + alexnet_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(32), + transforms.ToTensor(), + transforms.Normalize(mean,std)]) + + alexnet_test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(32), + transforms.ToTensor(), + transforms.Normalize(mean,std)]) + + @staticmethod + def get_transform(model_type, mode): + if model_type == 'resnet': + if mode == 'train': + return TinyImageNetTransform.resnet_train_transform + elif mode == 'test': + return TinyImageNetTransform.resnet_test_transform + elif model_type == 'vit': + if mode == 'train': + return TinyImageNetTransform.vit_train_transform + elif mode == 'test': + return TinyImageNetTransform.vit_test_transform + elif model_type == 'alexnet': + if mode == 'train': + return TinyImageNetTransform.alexnet_train_transform + elif mode == 'test': + return TinyImageNetTransform.alexnet_test_transform + else: + raise ValueError("Unsupported model type") + + +class FiveDatasetsTransform: + MEAN = [0.5071, 0.4866, 0.4409] + STD = [0.2675, 0.2565, 0.2761] + + common_trfs = [transforms.ToTensor(), + transforms.Normalize(mean=MEAN, std=STD)] + + resnet_train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63 / 255), + *common_trfs + ]) + + resnet_test_transform = transforms.Compose([ + transforms.Resize(32), + *common_trfs + ]) + + # from + dset_mean = (0., 0., 0.) + dset_std = (1., 1., 1.) + vit_train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(dset_mean, dset_std)]) + + vit_test_transform = transforms.Compose([ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize(dset_mean, dset_std)]) + + # from trust region gradient projection + mean=[x/255 for x in [125.3,123.0,113.9]] + std=[x/255 for x in [63.0,62.1,66.7]] + + alexnet_train_transform = transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor(), + transforms.Normalize(mean,std)]) + + alexnet_test_transform = transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor(), + transforms.Normalize(mean,std)]) + + @staticmethod + def get_transform(model_type, mode): + if model_type == 'resnet': + if mode == 'train': + return FiveDatasetsTransform.resnet_train_transform + elif mode == 'test': + return FiveDatasetsTransform.resnet_test_transform + elif model_type == 'vit': + if mode == 'train': + return FiveDatasetsTransform.vit_train_transform + elif mode == 'test': + return FiveDatasetsTransform.vit_test_transform + elif model_type == 'alexnet': + if mode == 'train': + return FiveDatasetsTransform.alexnet_train_transform + elif mode == 'test': + return FiveDatasetsTransform.alexnet_test_transform + else: + raise ValueError("Unsupported model type") + +transform_classes = { + 'cifar': CIFARTransform, + 'imagenet': ImageNetTransform, + 'imagenet-r': ImageNetRTransform, + 'tiny-imagenet': TinyImageNetTransform, + '5-datasets': FiveDatasetsTransform +} \ No newline at end of file diff --git a/core/data/dataloader.py b/core/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..64636d70bb2a50e24132e3c3aad91751a8823a1c --- /dev/null +++ b/core/data/dataloader.py @@ -0,0 +1,129 @@ +import os +import random +import numpy as np +import core.data.custom_transforms as cstf + +from torchvision import datasets, transforms +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from .dataset import ContinualDatasets, ImbalancedDatasets +from .data import transform_classes +from PIL import Image +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + +def _create_transforms(cfg): + transform_list = [] + + for item in cfg: + for func_name, params in item.items(): + + # Convert str to enum, if required + for k, v in params.items(): + if isinstance(v, str): + try: + params[k] = transforms.InterpolationMode[v] + except KeyError: + pass + + if func_name in cstf.custom_trfm_names: + transform = getattr(cstf, func_name) + else: + transform = getattr(transforms, func_name)(**params) + + transform_list.append(transform) + + return transforms.Compose(transform_list) + +def get_augment(config, mode='train'): + # Special judge for RAPF + if 'is_rapf' in config.keys() and config['is_rapf']: + def _convert_image_to_rgb(image): + return image.convert("RGB") + n_px = config['image_size'] + + return Compose([ + transforms.Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + if f'{mode}_trfms' in config.keys(): + return _create_transforms(config[f'{mode}_trfms']) + + # TODO: currently keeping below part for backward compatibility, will be remove in future + + d = {'dataset': 'cifar', + 'backbone': 'resnet', + 'mode': mode} + + if 'dataset' in config.keys(): + if 'cifar' in config['dataset']: + d['dataset'] = 'cifar' + else: + d['dataset'] = config['dataset'] + + if 'vit' in config['backbone']['name'].lower(): + d['backbone'] = 'vit' + if 'alexnet' in config['backbone']['name'].lower(): + d['backbone'] = 'alexnet' + + return transform_classes[d['dataset']].get_transform(d['backbone'], d['mode']) + +def get_dataloader(config, mode, cls_map=None): + ''' + Initialize the dataloaders for Continual Learning. + + Args: + config (dict): Parsed config dict. + mode (string): 'trian' or 'test'. + cls_map (dict): record the map between class and labels. + + Returns: + Dataloaders (list): a list of dataloaders + ''' + + task_num = config['task_num'] + init_cls_num = config['init_cls_num'] + inc_cls_num = config['inc_cls_num'] + + data_root = config['data_root'] + num_workers = config['num_workers'] + dataset = config['dataset'] + + trfms = get_augment(config, mode) + + if f'{mode}_batch_size' in config.keys(): + batch_size = config[f'{mode}_batch_size'] + else: + batch_size = config['batch_size'] + + if dataset == 'tiny-imagenet': + cls_map = {} + with open(os.path.join(os.getcwd(), "core", "data", "dataset_reqs", f"tinyimagenet_classes.txt"), "r") as f: + for line in f.readlines(): + _, cls_code, cls_name = line.strip().split('\t') + cls_map[cls_code] = cls_name + + elif cls_map is None and dataset != 'binary_cifar100': + # Apply class_order for debugging + cls_list = sorted(os.listdir(os.path.join(data_root, mode))) + #random.shuffle(cls_list) + if 'class_order' in config.keys(): + class_order = config['class_order'] + perm = class_order + else: + perm = np.random.permutation(len(cls_list)) + cls_map = dict() + for label, ori_label in enumerate(perm): + cls_map[label] = cls_list[ori_label] + + if mode == 'train' and 'imb_type' in config.keys(): + # generate long-tailed data to reproduce DAP + return ImbalancedDatasets(mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map, trfms, batch_size, num_workers, config['imb_type'], config['imb_factor'], config['shuffle']) + + return ContinualDatasets(dataset, mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map, trfms, batch_size, num_workers, config) + diff --git a/core/data/dataset.py b/core/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e0fe5cbb305487cdd62dacfb8d2f5a33453ff2b4 --- /dev/null +++ b/core/data/dataset.py @@ -0,0 +1,304 @@ +import os +import torch +import pickle +import random +import numpy as np + +from PIL import Image +from torchvision import datasets +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from continuum.datasets import TinyImageNet200 +from continuum import ClassIncremental + +class ContinualDatasets: + def __init__(self, dataset, mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map, trfms, batchsize, num_workers, config): + self.mode = mode + self.task_num = task_num + self.init_cls_num = init_cls_num + self.inc_cls_num = inc_cls_num + self.data_root = data_root + self.cls_map = cls_map + self.trfms = trfms + self.batchsize = batchsize + self.num_workers = num_workers + self.config = config + self.dataset = dataset + + if self.dataset == 'binary_cifar100': + datasets.CIFAR100(self.data_root, download = True) + + self.create_loaders() + + def create_loaders(self): + self.dataloaders = [] + + if self.dataset == 'tiny-imagenet': + + if 'class_order' in self.config: + class_order = self.config['class_order'] + else: + class_order = list(range(200)) + random.seed(self.config['seed']) + random.shuffle(class_order) + + scenario = ClassIncremental( + TinyImageNet200(self.data_root, train=self.mode == 'train', download=True), + initial_increment=self.init_cls_num, + increment=self.inc_cls_num, + class_order=class_order + ) + + class_ids_per_task = ( + [class_order[:self.init_cls_num]] + + [class_order[i:i + self.inc_cls_num] for i in range(self.init_cls_num, len(class_order), self.inc_cls_num)] + ) + + with open(os.path.join(os.getcwd(), "core", "data", "dataset_reqs", f"tinyimagenet_classes.txt"), "r") as f: + lines = f.read().splitlines() + classes_names = [line.split("\t")[-1] for line in lines] + + for t in range(self.task_num): + + cur_scenario = scenario[t:t+1] + + dataset = SingleDataset(self.dataset, self.data_root, self.mode, self.init_cls_num, self.inc_cls_num, self.cls_map, self.trfms, init=False) + dataset.images = cur_scenario._x + dataset.labels = cur_scenario._y + dataset.labels_name = [classes_names[class_id] for class_id in class_ids_per_task[t]] + + self.dataloaders.append(DataLoader( + dataset, + shuffle = True, + batch_size = self.batchsize, + drop_last = False, + num_workers = self.num_workers, + pin_memory=self.config['pin_memory'] + )) + + else: + + for i in range(self.task_num): + + start_idx = 0 if i == 0 else (self.init_cls_num + (i-1) * self.inc_cls_num) + end_idx = start_idx + (self.init_cls_num if i ==0 else self.inc_cls_num) + self.dataloaders.append(DataLoader( + SingleDataset(self.dataset, self.data_root, self.mode, self.init_cls_num, self.inc_cls_num, self.cls_map, self.trfms, start_idx, end_idx), + shuffle = True, + batch_size = self.batchsize, + drop_last = False, + num_workers = self.num_workers, + pin_memory=False + )) + + def get_loader(self, task_idx): + assert task_idx >= 0 and task_idx < self.task_num + if self.mode == 'train': + return self.dataloaders[task_idx] + else: + return self.dataloaders[:task_idx+1] + +class ImbalancedDatasets(ContinualDatasets): + def __init__(self, mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map, trfms, batchsize, num_workers, imb_type='exp', imb_factor=0.002, shuffle=False): + self.imb_type = imb_type + self.imb_factor = imb_factor + self.shuffle = shuffle + super().__init__(mode, task_num, init_cls_num, inc_cls_num, data_root, cls_map, trfms, batchsize, num_workers) + + def create_loaders(self): + self.dataloaders = [] + cls_num = self.init_cls_num + self.inc_cls_num * (self.task_num - 1) + img_num_list = self._get_img_num_per_cls(cls_num, self.imb_type, self.imb_factor) + + if self.shuffle: + grouped_img_nums = [img_num_list[i:i + self.inc_cls_num] for i in range(0, cls_num, self.inc_cls_num)] + np.random.shuffle(grouped_img_nums) + for group in grouped_img_nums: + np.random.shuffle(group) + shuffled_img_num_list = [num for group in grouped_img_nums for num in group] + img_num_list = shuffled_img_num_list + + for i in range(self.task_num): + start_idx = 0 if i == 0 else (self.init_cls_num + (i - 1) * self.inc_cls_num) + end_idx = start_idx + (self.init_cls_num if i == 0 else self.inc_cls_num) + dataset = SingleDataset(self.data_root, self.mode, self.cls_map, self.trfms, start_idx, end_idx) + + new_imgs, new_labels = [], [] + labels_np = np.array(dataset.labels, dtype=np.int64) + classes = np.unique(labels_np) + for the_class, the_img_num in zip(classes, img_num_list[i * self.inc_cls_num:(i + 1) * self.inc_cls_num]): + idx = np.nonzero(labels_np == the_class)[0] + np.random.shuffle(idx) + selec_idx = idx[:the_img_num] + new_imgs.extend([dataset.images[j] for j in selec_idx]) + new_labels.extend([the_class, ] * the_img_num) + dataset.images = new_imgs + dataset.labels = new_labels + + self.dataloaders.append(DataLoader( + dataset, + batch_size = self.batchsize, + drop_last = False + )) + + def _get_img_num_per_cls(self, cls_num, imb_type, imb_factor): + img_max = len(os.listdir(os.path.join(self.data_root, self.mode, self.cls_map[0]))) + img_num_per_cls = [] + if imb_type == 'exp': + for cls_idx in range(cls_num): + num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) + img_num_per_cls.append(max(int(num), 1)) + elif imb_type == 'exp_re': + for cls_idx in range(cls_num): + num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) + img_num_per_cls.append(max(int(num), 1)) + img_num_per_cls.reverse() + elif imb_type == 'exp_max': + cls_per_group = cls_num//self.task_num + for cls_idx in range(cls_num): + if (cls_idx+1)%cls_per_group==1: + num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) + img_num_per_cls.append(int(num)) + elif imb_type == 'exp_max_re': + cls_per_group = cls_num//self.task_num + for cls_idx in range(cls_num): + if (cls_idx+1)%cls_per_group==1: + # print(cls_idx) + num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) + img_num_per_cls.append(int(num)) + img_num_per_cls.reverse() + + elif imb_type == 'exp_min': + cls_per_group = cls_num//self.task_num + for cls_idx in range(cls_num): + if (cls_idx+1)%cls_per_group==1: + # print(cls_idx) + num = img_max * (imb_factor**((cls_idx+cls_per_group-1) / (cls_num - 1.0))) + # print(num) + img_num_per_cls.append(int(num)) + + elif imb_type == 'half': + cls_per_group = cls_num // self.task_num + ratio = 2 + num = 1 + for cls_idx in range(cls_num): + if num > img_max: + num = img_max + img_num_per_cls.append(int(num)) + if (cls_idx + 1) % cls_per_group == 0: + num *= ratio + img_num_per_cls.reverse() + + elif imb_type == 'half_re': + cls_per_group = cls_num // self.task_num + ratio = 2 + num = 1 + for cls_idx in range(cls_num): + if num > img_max: + num = img_max + img_num_per_cls.append(int(num)) + if (cls_idx + 1) % cls_per_group == 0: + num *= ratio + + elif imb_type == 'halfbal': + cls_per_group = cls_num // self.task_num + N = img_max * cls_per_group + + total = 0 + for i in range(self.task_num): + total += N / (2**i) + print(total) + per_class_count = int(total / cls_num) + img_num_per_cls.extend([per_class_count] * cls_num) + + elif imb_type == 'oneshot': + img_num_per_cls.extend([1] * cls_num) + elif imb_type == 'step': + for cls_idx in range(cls_num // 2): + img_num_per_cls.append(int(img_max)) + for cls_idx in range(cls_num // 2): + img_num_per_cls.append(int(img_max * imb_factor)) + elif imb_type == 'fewshot': + for cls_idx in range(cls_num): + if cls_idx<50: + num = img_max + else: + num = img_max*0.01 + img_num_per_cls.append(int(num)) + else: + img_num_per_cls.extend([int(img_max)] * cls_num) + return img_num_per_cls + +class SingleDataset(Dataset): + def __init__(self, dataset, data_root, mode, init_cls_num, inc_cls_num, cls_map, trfms, start_idx=-1, end_idx=-1, init=True): + super().__init__() + self.dataset = dataset + self.data_root = data_root + self.mode = mode + self.init_cls_num = init_cls_num + self.inc_cls_num = inc_cls_num + self.cls_map = cls_map + self.start_idx = start_idx + self.end_idx = end_idx + self.trfms = trfms + + if init: + self.images, self.labels, self.labels_name = self._init_datalist() + + def __getitem__(self, idx): + if self.dataset == 'binary_cifar100': + + image = self.images[idx] + image = Image.fromarray(np.uint8(image)) + + elif self.dataset == 'tiny-imagenet': + img_path = self.images[idx] + image = Image.open(img_path).convert("RGB") + + else: + + img_path = self.images[idx] + image = Image.open(os.path.join(self.data_root, self.mode, img_path)).convert("RGB") + + label = self.labels[idx] + image = self.trfms(image) + + return {"image": image, "label": label} + + def __len__(self,): + return len(self.labels) + + def _init_datalist(self): + + imgs, labels, labels_name = [], [], [] + + if self.dataset == 'binary_cifar100': + + with open(os.path.join(self.data_root, 'cifar-100-python', self.mode), 'rb') as f: + load_data = pickle.load(f, encoding='latin1') + + for data, label in zip(load_data['data'], load_data['fine_labels']): + + if label in range(self.start_idx, self.end_idx): + r = data[:1024].reshape(32, 32) + g = data[1024:2048].reshape(32, 32) + b = data[2048:].reshape(32, 32) + + tt_data = np.dstack((r, g, b)) + + imgs.append(tt_data) + labels.append(label) + labels_name.append(label) + + else: + + for id in range(self.start_idx, self.end_idx): + img_list = [self.cls_map[id] + '/' + pic_path for pic_path in os.listdir(os.path.join(self.data_root, self.mode, self.cls_map[id]))] + imgs.extend(img_list) + labels.extend([id for _ in range(len(img_list))]) + labels_name.append(self.cls_map[id]) + + return imgs, labels, labels_name + + def get_class_names(self): + return self.labels_name \ No newline at end of file diff --git a/core/data/dataset_reqs/tinyimagenet_classes.txt b/core/data/dataset_reqs/tinyimagenet_classes.txt new file mode 100644 index 0000000000000000000000000000000000000000..db00294ce4dc30b8e3e3d56fd5576264cc94e2bf --- /dev/null +++ b/core/data/dataset_reqs/tinyimagenet_classes.txt @@ -0,0 +1,200 @@ +0 n02124075 Egyptian Mau +1 n04067472 fishing casting reel +2 n04540053 volleyball +3 n04099969 rocking chair +4 n07749582 lemon +5 n01641577 American bullfrog +6 n02802426 basketball +7 n09246464 cliff +8 n07920052 espresso +9 n03970156 plunger +10 n03891332 parking meter +11 n02106662 German Shepherd Dog +12 n03201208 dining table +13 n02279972 monarch butterfly +14 n02132136 brown bear +15 n04146614 school bus +16 n07873807 pizza +17 n02364673 guinea pig +18 n04507155 umbrella +19 n03854065 pipe organ +20 n03838899 oboe +21 n03733131 maypole +22 n01443537 goldfish +23 n07875152 pot pie +24 n03544143 hourglass +25 n09428293 beach +26 n03085013 computer keyboard +27 n02437312 arabian camel +28 n07614500 ice cream +29 n03804744 metal nail +30 n04265275 space heater +31 n02963159 cardigan +32 n02486410 baboon +33 n01944390 snail +34 n09256479 coral reef +35 n02058221 albatross +36 n04275548 spider web +37 n02321529 sea cucumber +38 n02769748 backpack +39 n02099712 Labrador Retriever +40 n07695742 pretzel +41 n02056570 king penguin +42 n02281406 sulphur butterfly +43 n01774750 tarantula +44 n02509815 red panda +45 n03983396 soda bottle +46 n07753592 banana +47 n04254777 sock +48 n02233338 cockroach +49 n04008634 missile +50 n02823428 beer bottle +51 n02236044 praying mantis +52 n03393912 freight car +53 n07583066 guacamole +54 n04074963 remote control +55 n01629819 fire salamander +56 n09332890 lakeshore +57 n02481823 chimpanzee +58 n03902125 payphone +59 n03404251 fur coat +60 n09193705 mountain +61 n03637318 lampshade +62 n04456115 torch +63 n02666196 abacus +64 n03796401 moving van +65 n02795169 barrel +66 n02123045 tabby cat +67 n01855672 goose +68 n01882714 koala +69 n02917067 high-speed train +70 n02988304 CD player +71 n04398044 teapot +72 n02843684 birdhouse +73 n02423022 gazelle +74 n02669723 academic gown +75 n04465501 tractor +76 n02165456 ladybug +77 n03770439 miniskirt +78 n02099601 Golden Retriever +79 n04486054 triumphal arch +80 n02950826 cannon +81 n03814639 neck brace +82 n04259630 sombrero +83 n03424325 gas mask or respirator +84 n02948072 candle +85 n03179701 desk +86 n03400231 frying pan +87 n02206856 bee +88 n03160309 dam +89 n01984695 spiny lobster +90 n03977966 police van +91 n03584254 iPod +92 n04023962 punching bag +93 n02814860 lighthouse +94 n01910747 jellyfish +95 n04596742 wok +96 n03992509 potter's wheel +97 n04133789 sandal +98 n03937543 pill bottle +99 n02927161 butcher shop +100 n01945685 slug +101 n02395406 pig +102 n02125311 cougar +103 n03126707 construction crane +104 n04532106 vestment +105 n02268443 dragonfly +106 n02977058 automated teller machine +107 n07734744 mushroom +108 n03599486 rickshaw +109 n04562935 water tower +110 n03014705 storage chest +111 n04251144 snorkel +112 n04356056 sunglasses +113 n02190166 fly +114 n03670208 limousine +115 n02002724 black stork +116 n02074367 dugong +117 n04285008 sports car +118 n04560804 water jug +119 n04366367 suspension bridge +120 n02403003 ox +121 n07615774 popsicle +122 n04501370 turnstile +123 n03026506 Christmas stocking +124 n02906734 broom +125 n01770393 scorpion +126 n04597913 wooden spoon +127 n03930313 picket fence +128 n04118538 rugby ball +129 n04179913 sewing machine +130 n04311004 through arch bridge +131 n02123394 Persian cat +132 n04070727 refrigerator +133 n02793495 barn +134 n02730930 apron +135 n02094433 Yorkshire Terrier +136 n04371430 swim trunks / shorts +137 n04328186 stopwatch +138 n03649909 lawn mower +139 n04417672 thatched roof +140 n03388043 fountain +141 n01774384 southern black widow +142 n02837789 bikini +143 n07579787 plate +144 n04399382 teddy bear +145 n02791270 barbershop +146 n03089624 candy store +147 n02814533 station wagon +148 n04149813 scoreboard +149 n07747607 orange +150 n03355925 flagpole +151 n01983481 American lobster +152 n04487081 trolleybus +153 n03250847 drumstick +154 n03255030 dumbbell +155 n02892201 brass memorial plaque +156 n02883205 bow tie +157 n03100240 convertible +158 n02415577 bighorn sheep +159 n02480495 orangutan +160 n01698640 American alligator +161 n01784675 centipede +162 n04376876 syringe +163 n03444034 go-kart +164 n01917289 brain coral +165 n01950731 sea slug +166 n03042490 cliff dwelling +167 n07711569 mashed potatoes +168 n04532670 viaduct +169 n03763968 military uniform +170 n07768694 pomegranate +171 n02999410 chain +172 n03617480 kimono +173 n06596364 comic book +174 n01768244 trilobite +175 n02410509 bison +176 n03976657 pole +177 n01742172 boa constrictor +178 n03980874 poncho +179 n02808440 bathtub +180 n02226429 grasshopper +181 n02231487 stick insect +182 n02085620 Chihuahua +183 n01644900 tailed frog +184 n02129165 lion +185 n02699494 altar +186 n03837869 obelisk +187 n02815834 beaker +188 n07720875 bell pepper +189 n02788148 baluster / handrail +190 n02909870 bucket +191 n03706229 magnetic compass +192 n07871810 meatloaf +193 n03447447 gondola +194 n02113799 Standard Poodle +195 n12267677 acorn +196 n03662601 lifeboat +197 n02841315 binoculars +198 n07715103 cauliflower +199 n02504458 African bush elephant diff --git a/core/model/InfLoRA.py b/core/model/InfLoRA.py new file mode 100644 index 0000000000000000000000000000000000000000..8668fc345a636622a3054e8e7ff6fb5401c30d50 --- /dev/null +++ b/core/model/InfLoRA.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{arXiv:2404.00228v3, + title = {InfLoRA: Interference-Free Low-Rank Adaptation for Continual Learning}, + author = {Yan-Shuo Liang and + Wu-Jun Li}, + booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, {CVPR} 2024, Seattle, Washington}, + publisher = {Computer Vision Foundation / {IEEE}}, + year = {2024}, + url = {https://arxiv.org/abs/2404.00228v3}, +} +https://openaccess.thecvf.com/content_CVPR_2019/html/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.html + +Adapted from https://github.com/liangyanshuo/InfLoRA?utm_source=catalyzex.com +""" + + +import torch +import torch.nn as nn +from torch import optim +from torch.nn import functional as F +from torch.nn.parameter import Parameter +from torch.utils.data import DataLoader + +import logging +import numpy as np +from tqdm import tqdm +from sklearn.cluster import KMeans + +from .backbone.vit_inflora import Attention_LoRA +from copy import deepcopy +import math +from .finetune import Finetune + + +class InfLoRA(Finetune): + + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + + self._network = backbone + + for module in self._network.modules(): + if isinstance(module, Attention_LoRA): + module.init_param() + + # 100 categories in total, parameter passed assignment + self.num_class = num_class + # number of known and number of classes + self._total_classes =0 + # Number of categories known before this task, initially 0, updated in beforetask + self._known_classes =0 + + # The current task number, initially -1. +1 for each new task + self._cur_task = -1 + # number of tasks incremented each time + self.inc_cls_num = kwargs["inc_cls_num"] + + self.device = kwargs["device"] + + # These parameters are used in update DualGPM + self.feature_list = [] + self.project_type = [] + self.lame = kwargs["lame"] + self.lamb = kwargs["lamb"] + self.total_sessions = kwargs["total_sessions"] + + def observe(self, data): + ''' + Called during the training phase, it inputs a batch of training examples and returns the prediction, accuracy, and forward loss. + + Code Reference: + https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py + ''' + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + # Offset the target because the forward function in _network only predicts 0-9 + y = y-self._known_classes + logits = self._network(x)['logits'] + loss = F.cross_entropy(logits, y) + _, preds = torch.max(logits, dim=1) + correct = preds.eq(y.expand_as(preds)).cpu().sum() + total = len(y) + acc = correct/total + acc = acc.item() + return preds, acc, loss + + def inference(self, data): + ''' + It is called in the inference phase to input a batch of test samples and return the classification result and accuracy. + Calling the interface function of _network returns the value batchsize*_total_classes. + + Code Reference: + https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py + ''' + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logits = self._network.interface(x) + _, preds = torch.max(logits, dim=1) + correct = preds.eq(y.expand_as(preds)).cpu().sum() + total = len(y) + acc = correct/total + acc = acc.item() + return preds, acc + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + It is called before the training of each task to update the parameters, select the branch for training, and update the lora_A matrix of the corresponding branch + + Code Reference: + https://github.com/gydpku/OCM/blob/main/test_cifar10.py + ''' + + # Update some variables + self._known_classes = self._total_classes + self._cur_task += 1 + self._total_classes = self._known_classes + self.inc_cls_num + self._network.update_fc(self._total_classes) + + self._network.to(self.device) + + # Freeze the model and only release the linear layer, and the lora_b layer corresponding to the task number to train + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + try: + if "classifier_pool" + "." + str(self._network.module.numtask - 1) in name: + param.requires_grad_(True) + if "lora_B_k" + "." + str(self._network.module.numtask - 1) in name: + param.requires_grad_(True) + if "lora_B_v" + "." + str(self._network.module.numtask - 1) in name: + param.requires_grad_(True) + except: + if "classifier_pool" + "." + str(self._network.numtask - 1) in name: + param.requires_grad_(True) + if "lora_B_k" + "." + str(self._network.numtask - 1) in name: + param.requires_grad_(True) + if "lora_B_v" + "." + str(self._network.numtask - 1) in name: + param.requires_grad_(True) + + # Check the layer to be trained + enabled = set() + for name, param in self._network.named_parameters(): + if param.requires_grad: + enabled.add(name) + + with torch.no_grad(): + # We run the trained data through the model in order to obtain the cur_matrix. This parameter is related to update_DualGPM + for batch_idx, batch in enumerate(train_loader): + inputs = batch["image"] + targets = batch["label"] + inputs, targets = inputs.to(self.device), targets.to(self.device) + inputs=F.interpolate(inputs, size=224, mode='bilinear', align_corners=False) + self._network(inputs, get_cur_feat=True) + + if self._cur_task == 0: + # Updating according to cur matrix requires A manually designed lora A + for module in self._network.modules(): + if isinstance(module, Attention_LoRA): + cur_matrix = module.cur_matrix + U, S, V = torch.linalg.svd(cur_matrix) + module.lora_A_k[self._cur_task].weight.data.copy_(U[:,:module.rank].T/math.sqrt(3)) + module.lora_A_v[self._cur_task].weight.data.copy_(U[:,:module.rank].T/math.sqrt(3)) + module.cur_matrix.zero_() + module.n_cur_matrix = 0 + else: + # Updating according to cur matrix requires A manually designed lora A + kk = 0 + for module in self._network.modules(): + if isinstance(module, Attention_LoRA): + cur_matrix = module.cur_matrix + if self.project_type[kk] == 'remove': + cur_matrix = cur_matrix - torch.mm(self.feature_mat[kk],cur_matrix) + else: + assert self.project_type[kk] == 'retain' + cur_matrix = torch.mm(self.feature_mat[kk],cur_matrix) + cU, cS, cV = torch.linalg.svd(cur_matrix, full_matrices=False) + module.lora_A_k[self._cur_task].weight.data.copy_(cU[:,:module.rank].T/math.sqrt(3)) + module.lora_A_v[self._cur_task].weight.data.copy_(cU[:,:module.rank].T/math.sqrt(3)) + module.cur_matrix.zero_() + module.n_cur_matrix = 0 + kk += 1 + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + Called after each task starts training, it is used to perform preliminary operations on the mapping matrix to facilitate the update of lora_a layer in the next round of before_task + ''' + with torch.no_grad(): + # Get cur_matrix + for batch_idx, batch in enumerate(train_loader): + inputs = batch["image"] + targets = batch["label"] + inputs, targets = inputs.to(self.device), targets.to(self.device) + inputs=F.interpolate(inputs, size=224, mode='bilinear', align_corners=False) + self._network(inputs, get_cur_feat=True) + # Preliminary operations on the mapping matrix + mat_list = [] + for module in self._network.modules(): + if isinstance(module, Attention_LoRA): + mat_list.append(deepcopy(module.cur_matrix)) + module.cur_matrix.zero_() + module.n_cur_matrix = 0 + self.update_DualGPM(mat_list) + self.feature_mat = [] + for p in range(len(self.feature_list)): + Uf=torch.Tensor(np.dot(self.feature_list[p],self.feature_list[p].transpose())) + print('Layer {} - Projection Matrix shape: {}'.format(p+1,Uf.shape)) + self.feature_mat.append(Uf) + + return + + def update_DualGPM (self, mat_list): + ''' + Code Reference: + https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py + ''' + threshold = (self.lame - self.lamb)*self._cur_task/self.total_sessions + self.lamb + print ('Threshold: ', threshold) + if len(self.feature_list) == 0: + # After First Task + for i in range(len(mat_list)): + activation = mat_list[i] + U,S,Vh = np.linalg.svd(activation, full_matrices=False) + # criteria (Eq-5) + sval_total = (S**2).sum() + sval_ratio = (S**2)/sval_total + r = np.sum(np.cumsum(sval_ratio) Ui.shape[0] : + self.feature_list[i]=Ui[:,0:Ui.shape[0]] + else: + self.feature_list[i]=Ui + else: + assert self.project_type[i] == 'retain' + activation = mat_list[i] + U1,S1,Vh1=np.linalg.svd(activation, full_matrices=False) + sval_total = (S1**2).sum() + # Projected Representation (Eq-8) + act_hat = np.dot(np.dot(self.feature_list[i],self.feature_list[i].transpose()),activation) + U,S,Vh = np.linalg.svd(act_hat, full_matrices=False) + # criteria (Eq-9) + sval_hat = (S**2).sum() + sval_ratio = (S**2)/sval_total + accumulated_sval = sval_hat/sval_total + + r = 0 + for ii in range (sval_ratio.shape[0]): + if accumulated_sval >= (1-threshold): + accumulated_sval -= sval_ratio[ii] + r += 1 + else: + break + if r == 0: + print ('Skip Updating DualGPM for layer: {}'.format(i+1)) + continue + + # update GPM by Projected Representation (Eq-8) + act_feature = self.feature_list[i] - np.dot(np.dot(U[:,0:r],U[:,0:r].transpose()),self.feature_list[i]) + Ui, Si, Vi = np.linalg.svd(act_feature) + self.feature_list[i]=Ui[:,:self.feature_list[i].shape[1]-r] + + print('-'*40) + print('Gradient Constraints Summary') + print('-'*40) + for i in range(len(self.feature_list)): + if self.project_type[i]=='remove' and (self.feature_list[i].shape[1] > (self.feature_list[i].shape[0]/2)): + feature = self.feature_list[i] + # ipdb.set_trace() + U, S, V = np.linalg.svd(feature) + new_feature = U[:,feature.shape[1]:] + self.feature_list[i] = new_feature + self.project_type[i] = 'retain' + elif self.project_type[i]=='retain': + assert self.feature_list[i].shape[1] <= (self.feature_list[i].shape[0]/2) + print ('Layer {} : {}/{} type {}'.format(i+1,self.feature_list[i].shape[1], self.feature_list[i].shape[0], self.project_type[i])) + print('-'*40) + + def _set_random(self,args): + ''' + Set random values on various devices to ensure repeatable results + ''' + torch.manual_seed(args['seed']) + torch.cuda.manual_seed(args['seed']) + torch.cuda.manual_seed_all(args['seed']) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False \ No newline at end of file diff --git a/core/model/InfLoRA_opt.py b/core/model/InfLoRA_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e1b2e4f885652c805ffc548d82ec6328393d6c --- /dev/null +++ b/core/model/InfLoRA_opt.py @@ -0,0 +1,460 @@ +""" +@inproceedings{liang2024inflora, + title={InfLoRA: Interference-Free Low-Rank Adaptation for Continual Learning}, + author={Liang, Yan-Shuo and Li, Wu-Jun}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={23638--23647}, + year={2024} +} + +Adapted from https://github.com/liangyanshuo/InfLoRA +""" + +import os +import math +import torch +import random +import torch.nn as nn +import numpy as np + +from torch import optim +from torch.nn import functional as F +from torch.nn.parameter import Parameter +from tqdm import tqdm +from .backbone.transformer import MultiHeadAttention_LoRA, VisionTransformer +from .backbone.clip import CLIP, tokenize +from .backbone.vit import ViTZoo + +VIT = ViTZoo +CLIP = CLIP + +def _set_random(seed): + ''' + Set random values on various devices to ensure repeatable results + ''' + + seed = int(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +class SiNet(nn.Module): + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self._cur_task_id = -1 + self.backbone = backbone + self.device = device + + if isinstance(backbone, VIT): + _set_random(os.environ["PYTHONHASHSEED"]) + self.classifier_pool = nn.ModuleList([ + nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] + + [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)] + ) + elif isinstance(backbone, CLIP): + self.accm_class_names = [] + self.curr_class_names = [] + self.accm_text_tokens = None + self.curr_text_tokens = None + + self.prompt_template = kwargs['prompt_template'] + else: + assert 0, f'Backbone not implemented' + + def update_fc(self, train_loader): + + self._cur_task_id += 1 + + if isinstance(self.backbone, CLIP): + + self.curr_class_names = train_loader.dataset.get_class_names() + self.accm_class_names += self.curr_class_names + + self.curr_text_tokens = tokenize( + [self.prompt_template.format(c) for c in self.curr_class_names] + ).to(self.device) + + self.accm_text_tokens = tokenize( + [self.prompt_template.format(c) for c in self.accm_class_names] + ).to(self.device) + + # These two for classifier alignment, + def get_feature(self, x): + if isinstance(self.backbone, VIT): + return self.backbone(x) + elif isinstance(self.backbone, CLIP): + assert 0 + else: + assert 0 + + def fc_only(self, x): + if isinstance(self.backbone, VIT): + logits = [] + for prompts in self.classifier_pool[:self._cur_task_id + 1]: + logits.append(prompts(x)) + return torch.cat(logits, dim=1) + elif isinstance(self.backbone, CLIP): + assert 0 + else: + assert 0 + + def forward(self, x, inference = False): + + if isinstance(self.backbone, VIT): + + logits = [] + features = self.backbone(x) + + if inference: + for prompts in self.classifier_pool[:self._cur_task_id + 1]: + logits.append(prompts(features)) + else: + for prompts in [self.classifier_pool[self._cur_task_id]]: + logits.append(prompts(features)) + + return torch.cat(logits, dim=1) + + elif isinstance(self.backbone, CLIP): + if inference: + features_img, features_txt, logits_per_img, logits_per_txt = self.backbone(x, self.accm_text_tokens) + else: + features_img, features_txt, logits_per_img, logits_per_txt = self.backbone(x, self.curr_text_tokens) + return logits_per_img + else: + assert 0, f'Backbone not implemented' + + def update_input_matrix(self, x): + + if isinstance(self.backbone, VIT): + self.backbone(x, get_input_matrix = True) + + elif isinstance(self.backbone, CLIP): + self.backbone(image = x, text = self.curr_text_tokens, get_input_matrix = True) + +class InfLoRA_OPT(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.device = device + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.task_num = kwargs["task_num"] + self.lame = kwargs["lame"] + self.lamb = kwargs["lamb"] + + self._known_classes = 0 + self.feature_list = [] + self.project_type = [] + + self._dataset = kwargs['dataset'] + self._use_class_alignment = kwargs['use_ca'] + self._logit_norm = None if self._dataset == 'cifar100' else 0.1 + self._class_means = None + self._class_covs = None + + self._network = SiNet(backbone, device, **kwargs).to(self.device) + + if isinstance(backbone, VIT): + self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_LoRA)] + elif isinstance(backbone, CLIP): + self.visual_only = kwargs['visual_only'] + if self.visual_only: + self.attention_modules = [module for name, module in self._network.named_modules() if isinstance(module, MultiHeadAttention_LoRA) and 'visual' in name] + else: + self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_LoRA)] + else: + assert 0, 'Not Implmented' + + def observe(self, data): + ''' + Called during the training phase, it inputs a batch of training examples and returns the prediction, accuracy, and forward loss. + ''' + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + logits = self._network(x) + loss = F.cross_entropy(logits, y) + + preds = logits.max(1)[1] + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc, loss + + def inference(self, data): + ''' + It is called in the inference phase to input a batch of test samples and return the classification result and accuracy. + Calling the interface function of _network returns the value batchsize*_total_classes. + ''' + + x, y = data['image'].to(self.device), data['label'].to(self.device) + logits = self._network(x, inference = True) + preds = logits.max(1)[1] + + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc + + @torch.no_grad() + def before_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + It is called before the training of each task to update the parameters, select the branch for training, and update the lora_A matrix of the corresponding branch + ''' + + if task_idx == 1: + self._known_classes = self.init_cls_num + elif task_idx > 1: + self._known_classes += self.inc_cls_num + self._network.update_fc(train_loader) + + _set_random(os.environ["PYTHONHASHSEED"]) + for module in self.attention_modules: + module.init_param() + + unfrezeed_params = [] + if isinstance(self._network.backbone, VIT): + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + if f"classifier_pool.{task_idx}." in name or "lora_B" in name: + param.requires_grad_(True) + unfrezeed_params.append(name) + elif isinstance(self._network.backbone, CLIP): + if self.visual_only: + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + if "visual" in name and "lora_B" in name: + param.requires_grad_(True) + unfrezeed_params.append(name) + else: + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + if "lora_B" in name: + param.requires_grad_(True) + unfrezeed_params.append(name) + + print(f"Current task : {task_idx}, Parameters to be updated: {len(unfrezeed_params)}") + print(",\n".join(unfrezeed_params)) + + _set_random(os.environ["PYTHONHASHSEED"]) + for batch in tqdm(train_loader, desc="Forwarding to get input matrix"): + self._network.update_input_matrix(x = batch['image'].to(self.device)) + + + if task_idx == 0: + for module in self.attention_modules: + assert module.n_cur_matrix > 0 + U, S, _ = torch.linalg.svd(module.cur_matrix, full_matrices=False) + + module.lora_A_k.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.lora_A_v.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.reset_input_matrix() + else: + for i, module in enumerate(self.attention_modules): + assert self.project_type[i] == 'remove' or self.project_type[i] == 'retain' + + cur_matrix = module.cur_matrix + feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) + + if self.project_type[i] == 'remove': + cur_matrix = cur_matrix - feature_mat @ cur_matrix + else: + cur_matrix = feature_mat @ cur_matrix + + U, _, _ = torch.linalg.svd(cur_matrix, full_matrices = False) + module.lora_A_k.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.lora_A_v.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.reset_input_matrix() + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + Called after each task before final testing, it is used to perform preliminary operations on the mapping matrix to facilitate the update of lora_a layer in the next round of before_task + ''' + + for module in self.attention_modules: + module.merge_weight() + + self._update_feature(task_idx, train_loader, test_loaders[0].dataset.trfms) + if self._use_class_alignment: + self._create_distribution(train_loader, test_loaders[0].dataset.trfms) + if task_idx > 0: + self._compact_classifier(task_idx) + + @torch.no_grad() + def _update_feature(self, task_idx, train_loader, test_trfms): + ''' + Update feature lists and the corresponding type + ''' + + _set_random(os.environ["PYTHONHASHSEED"]) + for batch in tqdm(train_loader, desc="Forwarding to get input matrix"): + + self._network.update_input_matrix(x = batch['image'].to(self.device)) + + threshold = (self.lame - self.lamb)*task_idx/self.task_num + self.lamb + + if task_idx == 0: + for i, attention_module in enumerate(self.attention_modules): + activation = attention_module.cur_matrix + + U, S, _ = np.linalg.svd(activation, full_matrices=False) + sval_total = (S**2).sum() + sval_ratio = (S**2)/sval_total + r = max(np.sum(np.cumsum(sval_ratio) < threshold), 1) + assert r < activation.shape[0]/2 + + self.feature_list.append(U[:, :r]) + self.project_type.append('remove') + + attention_module.reset_input_matrix() + else: + for i, attention_module in enumerate(self.attention_modules): + + activation = attention_module.cur_matrix + _, S, _ = np.linalg.svd(activation, full_matrices=False) + sval_total = (S**2).sum() + + if self.project_type[i] == 'remove': + + act_hat = activation - torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) @ activation + U, S, _ = np.linalg.svd(act_hat, full_matrices = False) + sval_hat = (S**2).sum() + sval_ratio = (S**2)/sval_total + accumulated_sval = (sval_total-sval_hat)/sval_total + + if accumulated_sval >= threshold: + print (f'Skip Updating DualGPM for layer: {i+1}') + else: + r = np.sum(np.cumsum(sval_ratio) + accumulated_sval < threshold) + 1 + Ui = np.hstack((self.feature_list[i], U[:, :r])) + self.feature_list[i] = Ui[:, :min(Ui.shape[0], Ui.shape[1])] + + else: + act_hat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) @ activation + U,S,_ = np.linalg.svd(act_hat, full_matrices = False) + sval_hat = (S**2).sum() + sval_ratio = (S**2)/sval_total + accumulated_sval = sval_hat/sval_total + + if accumulated_sval < 1 - threshold: + print (f'Skip Updating Space for layer: {i+1}') + else: + r = np.sum(accumulated_sval - np.cumsum(sval_ratio) >= 1 - threshold) + 1 + act_feature = self.feature_list[i] - U[:,0:r] @ U[:,0:r].T @ self.feature_list[i] + U, _, _ = np.linalg.svd(act_feature) + self.feature_list[i]=U[:,:self.feature_list[i].shape[1]-r] + + attention_module.reset_input_matrix() + + print('-'*40) + print(f'Threshold: {threshold}') + print('-'*40) + for i in range(len(self.feature_list)): + if self.project_type[i]=='remove' and (self.feature_list[i].shape[1] > (self.feature_list[i].shape[0]/2)): + feature = self.feature_list[i] + U, S, V = np.linalg.svd(feature) + new_feature = U[:,feature.shape[1]:] + self.feature_list[i] = new_feature + self.project_type[i] = 'retain' + elif self.project_type[i]=='retain': + assert self.feature_list[i].shape[1] <= (self.feature_list[i].shape[0]/2) + print ('Layer {} : {}/{} type {}'.format(i+1,self.feature_list[i].shape[1], self.feature_list[i].shape[0], self.project_type[i])) + print('-'*40) + + @torch.no_grad() + def _create_distribution(self, train_loader, test_trfms): + + self._network.eval() + train_loader.dataset.trfms = test_trfms + + samples = [[] for _ in range(self.inc_cls_num)] + for batch in train_loader: + x, y = batch['image'], batch['label'] - self._known_classes + for label in range(self.inc_cls_num): + samples[label].append(x[y == label]) + samples = [torch.cat(label_sample, dim = 0).to(self.device) for label_sample in samples] + + # Computing class mean + if self._class_means is None: + self._class_means = torch.zeros((self.init_cls_num, 768)) + self._class_covs = torch.zeros((self.init_cls_num, 768, 768)) + else: + self._class_means = torch.cat((self._class_means, torch.zeros((self.inc_cls_num, 768))), dim=0) + self._class_covs = torch.cat((self._class_covs, torch.zeros((self.inc_cls_num, 768, 768))), dim=0) + + for class_idx, x in enumerate(samples): + class_idx += self._known_classes + features = self._network.get_feature(x) + + self._class_means[class_idx, :] = torch.mean(features, dim = 0) + self._class_covs[class_idx, :, :] = torch.cov(features.to(torch.float64).T) + torch.eye(768, device = self.device) * 1e-4 + + def _compact_classifier(self, task_idx): + + # Hyperparam + epoch = 5 + lr = 0.01 + weight_decay = 0.0005 + momentum = 0.9 + num_sample = 256 + + for param in self._network.classifier_pool[:task_idx + 1].parameters(): + param.requires_grad_(True) + param_list = [param for param in self._network.classifier_pool.parameters() if param.requires_grad] + + optimizer = optim.SGD(param_list, lr=lr, momentum=momentum, weight_decay=weight_decay) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=epoch) + + for ep in range(epoch): + sampled_data, sampled_label = [], [] + + for class_id in range((task_idx + 1) * self.inc_cls_num): + task_id = class_id // self.inc_cls_num + + decay = (task_id + 1) / (task_idx + 1) * 0.1 + cls_mean = self._class_means[class_id].to(self.device, torch.float64) * (0.9 + decay) + cls_cov = self._class_covs[class_id].to(self.device) + + m = torch.distributions.multivariate_normal.MultivariateNormal(cls_mean.float(), cls_cov.float()) + + sampled_data_single = m.sample(sample_shape=(num_sample,)) + sampled_data.append(sampled_data_single) + sampled_label.extend([class_id] * num_sample) + + inputs = torch.cat(sampled_data, dim=0).float().to(self.device) + targets = torch.tensor(sampled_label).long().to(self.device) + + # Randomize + sf_indexes = torch.randperm(inputs.size(0)) + inputs = inputs[sf_indexes] + targets = targets[sf_indexes] + + for _iter in range((task_idx + 1) * self.inc_cls_num): + + inp = inputs[_iter * num_sample : (_iter+1) * num_sample] + tgt = targets[_iter * num_sample : (_iter+1) * num_sample] + logits = self._network.fc_only(inp) + + if self._logit_norm: + + pass + + else: + loss = F.cross_entropy(logits, tgt) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + scheduler.step() + + def get_parameters(self, config): + return self._network.parameters() + diff --git a/core/model/MInfLoRA.py b/core/model/MInfLoRA.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ed9d8594845e94246e56f7aa81e40a89d997e4 --- /dev/null +++ b/core/model/MInfLoRA.py @@ -0,0 +1,627 @@ +""" +Code Reference: +https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py +""" + +import os +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np +import matplotlib.pyplot as plt + +from tqdm import tqdm +from .backbone.transformer import MultiHeadAttention_MaskedLoRA1 + +GREEDY=True +APPROX_FEAT=True + +Epsilon = 0.5 + +def _set_random(seed): + ''' + Set random values on various devices to ensure repeatable results + ''' + + seed = int(seed) + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def select_probe_greedy_span_unified_normalized(cur_matrixs_list, energy_threshold=0.95, top_r=None): + """ + Greedy span selection across multiple attention blocks, with per-block normalization. + Dynamically select samples that together span a certain percentage of gradient space. + + Args: + cur_matrixs_list (List[torch.Tensor]): list of (Num, 768, 768) tensors for each block. + energy_threshold (float): fraction of gradient space to cover. + top_r (int, optional): number of top singular vectors to use. + + Returns: + selected_indices (torch.Tensor) + """ + N = cur_matrixs_list[0].shape[0] + device = cur_matrixs_list[0].device + + # 1. Normalize each block independently + normalized_cur_matrices = [] + for matrices in cur_matrixs_list: + frob_norms = torch.norm(matrices.view(N, -1), dim=-1, p=2).view(N, 1, 1) # (N, 1, 1) + matrices_normalized = matrices / (frob_norms + 1e-8) + normalized_cur_matrices.append(matrices_normalized) + + # 2. Compute global covariance C + C_global = sum([matrices.sum(dim=0) for matrices in normalized_cur_matrices]) + + # 3. SVD on global C + U, _, _ = torch.linalg.svd(C_global) + if top_r is not None: + U = U[:, :top_r] + + # 4. Compute projected sample vectors + projected_vectors = [] + for i in range(N): + x_cov_sum = sum([matrices[i] for matrices in normalized_cur_matrices]) # (768, 768) + proj = U.T @ x_cov_sum @ U # (top_r, top_r) + projected_vectors.append(proj.flatten()) # flatten to (top_r*top_r,) + projected_vectors = torch.stack(projected_vectors, dim=0) # (N, top_r*top_r) + + # 5. Greedy selection + selected_indices = [] + remaining_indices = set(range(N)) + selected_vectors = [] + + + + #total_energy = projected_vectors.norm(dim=-1).pow(2).sum().item() + current_energy = 0.0 + + while current_energy / total_energy < energy_threshold: + best_idx = -1 + best_gain = -float('inf') + + for idx in remaining_indices: + vec = projected_vectors[idx] # (top_r*top_r,) + if selected_vectors: + # Project onto orthogonal complement + selected_mat = torch.stack(selected_vectors, dim=0) # (num_selected, D) + + Q, _ = torch.linalg.qr(selected_mat.T, mode='reduced') # (D, k) + projection = (Q @ (Q.T @ vec)) # (D,) + + #projection = (vec @ selected_mat.T) @ selected_mat # (D,) + vec_residual = vec - projection + else: + vec_residual = vec + + gain = vec_residual.norm().item() + + if gain > best_gain: + best_gain = gain + best_idx = idx + + selected_indices.append(best_idx) + remaining_indices.remove(best_idx) + selected_vectors.append(projected_vectors[best_idx] / (projected_vectors[best_idx].norm() + 1e-8)) + current_energy += best_gain ** 2 + + selected_indices = torch.tensor(selected_indices) + + print(f"Selected {len(selected_indices)} samples covering {current_energy / total_energy * 100:.2f}% of gradient space.") + + # 7. Optional plotting + plt.figure(figsize=(8, 6)) + plt.plot(torch.arange(len(selected_indices))+1, (torch.tensor([current_energy / total_energy for _ in selected_indices])*100).numpy(), label='Cumulative Span Coverage') + plt.xlabel('Number of Samples Selected') + plt.ylabel('Coverage (%)') + plt.title('Greedy Span Selection Coverage') + plt.grid(True) + plt.legend() + plt.savefig('greedy_span_coverage.png', dpi=300) + + return selected_indices + +def select_probe_greedy_span_unified_normalized_high_precision( + cur_matrixs_list, + energy_threshold=0.95, + top_r=None +): + """ + Greedy span selection across multiple attention blocks, with per-block normalization. + Dynamically select samples that together span a certain percentage of gradient space. + + Args: + cur_matrixs_list (List[torch.Tensor]): list of (Num, 768, 768) tensors for each block. + energy_threshold (float): fraction of gradient space to cover. + top_r (int, optional): number of top singular vectors to use. + feature_mode (str): "trace" (default) or "flatten". How to extract projected features. + + Returns: + selected_indices (torch.Tensor) + """ + + N = cur_matrixs_list[0].shape[0] + + # 1. Normalize each block independently + normalized_cur_matrices = [] + for matrices in cur_matrixs_list: + frob_norms = torch.norm(matrices.view(N, -1), dim=-1, p=2).view(N, 1, 1) # (N, 1, 1) + matrices_normalized = matrices / (frob_norms + 1e-8) + normalized_cur_matrices.append(matrices_normalized) + + # 2. Compute global covariance C + C_global = sum([matrices.sum(dim=0) for matrices in normalized_cur_matrices]) # (768, 768) + + # 3. SVD on global C + U, _, _ = torch.linalg.svd(C_global) + if top_r is not None: + U = U[:, :top_r] # (768, top_r) + + # 4. Compute projected sample vectors + projected_vectors = [] + for i in range(N): + x_cov_sum = sum([matrices[i] for matrices in normalized_cur_matrices]) # (768, 768) + proj = U.T @ x_cov_sum @ U # (top_r, top_r) + proj_feat = proj.flatten() # (top_r*top_r,) + projected_vectors.append(proj_feat) + projected_vectors = torch.stack(projected_vectors, dim=0) # (N, D) + + # 5. Greedy selection with orthogonal residual updates + selected_indices = [] + remaining_indices = set(range(N)) + + residual_vectors = projected_vectors.clone() # (N, D) + selected_vectors = [] + + total_energy = projected_vectors.norm(dim=-1).pow(2).sum().item() + current_energy = 0.0 + + # 假设 selected_vectors 已经初始化 + while current_energy / total_energy < energy_threshold: + assert remaining_indices + best_idx = -1 + best_gain = -float('inf') + + for idx in remaining_indices: + vec = residual_vectors[idx] + if selected_vectors: + # 计算当前向量与已选向量的正交残差 + selected_mat = torch.stack(selected_vectors, dim=0) # (num_selected, D) + Q, _ = torch.linalg.qr(selected_mat.T, mode='reduced') # (D, k) + vec_residual = vec - (Q @ (Q.T @ vec)) + else: + vec_residual = vec + + # 当前样本的能量 + gain = vec_residual.norm().item() ** 2 + + if gain > best_gain: + best_gain = gain + best_idx = idx + if not GREEDY: + break + + # 累计选择的样本能量 + selected_indices.append(best_idx) + selected_vec = residual_vectors[best_idx] / (residual_vectors[best_idx].norm() + 1e-8) + current_energy += projected_vectors[best_idx].norm().item() ** 2 + print(current_energy, '/', total_energy) + + # 更新 selected_vectors 和 residual_vectors + selected_vectors.append(selected_vec) + projection = (residual_vectors @ selected_vec.unsqueeze(-1)).squeeze(-1) + residual_vectors = residual_vectors - projection.unsqueeze(-1) * selected_vec.unsqueeze(0) + + remaining_indices.discard(best_idx) + + # 输出最终的选择结果 + selected_indices = torch.tensor(selected_indices) + print(f"Selected {len(selected_indices)} samples covering {current_energy / total_energy * 100:.2f}% of gradient space.") + + return selected_indices + +# ------ + +class TopK: + + ''' + A class to maintain a collection of the top K items based on a specified attribute. + + This class allows for the dynamic addition of items, each represented as a dictionary, + where each dictionary must have a key 'proj_norm' that represents the value used + to determine the ranking. The class keeps track of the top K items with the highest + 'proj_norm' values. + ''' + + def __init__(self, k): + self.k = k + self.top_k_list = [] + + def add(self, dict): + if len(self.top_k_list) < self.k: + self.top_k_list.append(dict) + elif dict['proj_norm'] > min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: + self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['proj_norm'])) + self.top_k_list.append(dict) + elif dict['proj_norm'] == min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm'] and \ + dict['proj_norm'] == max(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: + self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['task_id'])) + self.top_k_list.append(dict) + + def get_top_k(self): + return self.top_k_list + +class SiNet(nn.Module): + def __init__(self, backbone, **kwargs): + super().__init__() + + self._cur_task_id = -1 + self.backbone = backbone + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + + _set_random(os.environ["PYTHONHASHSEED"]) + self.classifier_pool = nn.ModuleList([ + nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] + + [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)]) + + for name, module in self.backbone.named_modules(): + if 'transformer' in name and 'blocks' not in name: + self.transformer_module = module + + def update_fc(self): + self._cur_task_id += 1 + + def forward(self, x, expert_id, inference = False): + logits = [] + features = self.backbone(x, expert_id = expert_id) + + if inference: + + # Bayesian + for i, prompts in enumerate(self.classifier_pool[:self._cur_task_id + 1]): + # No Masking + logits.append(prompts(features)) + + logits = torch.cat(logits, dim=1) + + return logits + + else: + logits.append(self.classifier_pool[self._cur_task_id](features)) + return torch.cat(logits, dim=1) + + def update_input_matrix(self, x): + self.backbone(x, expert_id = -1, get_input_matrix = True) + +class MInfLoRA(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.device = device + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.task_num = kwargs["task_num"] + self.lame = kwargs["lame"] + self.lamb = kwargs["lamb"] + self.embd_dim = kwargs["embd_dim"] + self.eval_mat = False + + self._known_classes = 0 + self.feature_list = [] + self.project_type = [] + + self.distributed = torch.distributed.is_initialized() + assert not self.distributed, 'current not support' + self.local_rank = torch.distributed.get_rank() if self.distributed else 0 + + self._network = SiNet(backbone, **kwargs) + + self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_MaskedLoRA1)] + + # TRGP Implementation + self.feature_list_each_tasks = [[np.zeros((1)) for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + self.final_decision = [[np.zeros((1)) for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + self.before_mat = [[0 for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + + self.experts_distributions = [] + + # Class Alignment Implementation + self._use_class_alignment = kwargs['use_ca'] + self._class_means = None + self._class_covs = None + self._dataset = kwargs['dataset'] + if self._dataset == 'cifar': + self.logit_norm = None + else: + self.logit_norm = 0.1 + + self.lll = [] + + self._network.to(self.device) + + def observe(self, data): + + with torch.no_grad(): + self._network(self.probe_selection, expert_id = -1) + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + logits = self._network(x, expert_id = self._network._cur_task_id) + loss = F.cross_entropy(logits, y) + + preds = logits.max(1)[1] + acc = preds.eq(y).sum().item() / y.shape[0] + + return preds, acc, loss + + def inference(self, data, **kwargs): + + task_id = kwargs['task_id'] if 'task_id' in kwargs else -1 + x, y = data['image'].to(self.device, non_blocking=True), data['label'].to(self.device, non_blocking=True) + + logits = self._network(x, expert_id = task_id, inference = True) + preds = logits.max(1)[1] + acc = preds.eq(y).sum().item() / y.shape[0] + + return preds, acc + + @torch.no_grad() + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + print('Greedy', GREEDY) # current best is not greedy, yes approx feature + print('Approx Feature', APPROX_FEAT) + + self._network.update_fc() + + # mag = nn.ParameterList([nn.Parameter(torch.Tensor([1.0])) for _ in range(task_idx + 1)]) + _set_random(os.environ["PYTHONHASHSEED"]) + for module in self.attention_modules: + #module.mag_lora = mag + module.init_param() + + self._network = self._network.to(self.device) + self._update_input_matrix(train_loader) + + ''' + probe_indices_svd = select_probe_svd_energy_matrix_unified_normalized( + [m.cur_matrixs for m in self.attention_modules] + ,probe_size=512 + ) + ''' + ''' + self.probe_indices_svd = select_probe_svd_energy_matrix_unified_normalized( + [m.cur_matrixs for m in self.attention_modules] + ,energy_threshold=0.15, top_r=64 + ) + ''' + self.probe_indices_svd = select_probe_greedy_span_unified_normalized_high_precision( + [m.cur_matrixs for m in self.attention_modules] + ,energy_threshold=0.01, top_r=128 + #,energy_threshold=0.5, top_r=128 + ) + + self.probe_selection = self.dataset[self.probe_indices_svd].to(self.device) + + if task_idx == 0: + for i, module in enumerate(self.attention_modules): + + # Either divide with 512 or divice with 512 * 197 + U, _, _ = torch.linalg.svd(module.cur_matrixs[self.probe_indices_svd].sum(dim=0) / 512, full_matrices=False) + + module.lora_A_k_list[task_idx].weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.lora_A_v_list[task_idx].weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + + else: + for i, module in enumerate(self.attention_modules): + + feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) + module.feature_mat = feature_mat.clone().to(self.device) + + activation = module.cur_matrixs[self.probe_indices_svd].sum(dim=0) / 512 + activation = activation - feature_mat @ activation + + U, _, _ = torch.linalg.svd(activation, full_matrices = False) + + module.lora_A_k_list[task_idx].weight.data.copy_(U[:, :module.lora_rank].T/(3 ** 0.5)) + module.lora_A_v_list[task_idx].weight.data.copy_(U[:, :module.lora_rank].T/(3 ** 0.5)) + + ''' + for i, module in enumerate(self.attention_modules): + + topk = TopK(1) + + mat = module.cur_matrix.cpu().numpy() + mat_norm = np.linalg.norm(mat) + + for task_id in range(task_idx): + + if not np.array_equal(self.feature_list_each_tasks[task_id][i], np.zeros((1))): + + proj_norm = np.linalg.norm(self.feature_list_each_tasks[task_id][i] @ self.feature_list_each_tasks[task_id][i].T @ mat) + print(f'{task_idx} to {task_id} in layer {i} : {proj_norm}') + + if proj_norm > Epsilon * mat_norm: + topk.add({'proj_norm':proj_norm, 'task_id': task_id}) + + self.final_decision[task_idx][i] = [dic['task_id'] for dic in topk.get_top_k()] + print(f'Layer {i} of {task_idx} consider {self.final_decision[task_idx][i]} as trust region') + + self.prev_matrix = [] + if task_idx == 0: + for i, module in enumerate(self.attention_modules): + + U, _, _ = torch.linalg.svd(module.cur_matrix) + U = torch.Tensor(U).to(self.device) + + self.prev_matrix.append(U[:,:module.lora_rank].T.cpu()) + + module.lora_A_k_list[task_idx].weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.lora_A_v_list[task_idx].weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + #module.reset_input_matrix() + else: + for i, module in enumerate(self.attention_modules): + assert self.project_type[i] == 'remove' or self.project_type[i] == 'retain' + + cur_matrix = module.cur_matrix.to(self.device) + + + # TRGP + tr = self.final_decision[task_idx][i][0] + tr = task_idx - 1 + + #feature_mat = torch.Tensor(self.feature_list_each_tasks[tr][i] @ self.feature_list_each_tasks[tr][i].T).to(self.device) + + feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T).to(self.device) + intersect = feature_mat @ cur_matrix + + target_shape = 768 + + U, _, _ = np.linalg.svd(intersect.cpu().numpy(), full_matrices = False) + U = torch.Tensor(U).to(self.device) + module.space_k[tr] = U[:, :target_shape].T/math.sqrt(3) + module.space_v[tr] = U[:, :target_shape].T/math.sqrt(3) + + # InfLoRA + feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T).to(self.device) + + if self.project_type[i] == 'remove': + cur_matrix = cur_matrix - feature_mat @ cur_matrix + else: + cur_matrix = feature_mat @ cur_matrix + + module.feature_mat = feature_mat.clone() + + U, _, _ = np.linalg.svd(cur_matrix.cpu().numpy(), full_matrices = False) + U = U[:, :module.lora_rank] + + alphas = torch.linalg.lstsq(torch.Tensor(module.lora_A_k_list[task_idx-1].weight.data).T.cpu(), torch.Tensor(U) / math.sqrt(3)) + if alphas.residuals.numel() != 0: + print(f'Task {task_idx}, Layer {i}, {alphas.residuals}') + assert 0 + + U = torch.Tensor(U).to(self.device) + + module.lora_A_k_list[task_idx].weight.data.copy_(U[:, :module.lora_rank].T/math.sqrt(3)) # here should have /sqrt3 + module.lora_A_v_list[task_idx].weight.data.copy_(U[:, :module.lora_rank].T/math.sqrt(3)) + ''' + + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + if f"classifier_pool.{task_idx}" in name or \ + f"lora_B_k_list.{task_idx}" in name or \ + f"lora_B_v_list.{task_idx}" in name: + param.requires_grad_(True) + + for name, param in self._network.named_parameters(): + if param.requires_grad: + print(name) + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + Called after each task before final testing, it is used to perform preliminary operations on the mapping matrix to facilitate the update of lora_a layer in the next round of before_task + ''' + + self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + + self._update_feature(task_idx, train_loader, test_loaders) + + @torch.no_grad() + def _update_feature(self, task_idx, train_loader, test_loaders): + ''' + Update feature lists and the corresponding type + ''' + + self._update_input_matrix(train_loader) + + if self.local_rank == 0: + + threshold = (self.lame - self.lamb)*task_idx/self.task_num + self.lamb + + if task_idx == 0: + for i, module in enumerate(self.attention_modules): + + activation = module.cur_matrixs[self.probe_indices_svd].sum(dim=0) / 512 + U, S, _ = torch.linalg.svd(activation, full_matrices=False) + true_U = U[:, :module.lora_rank] + + # Least Square + alphas = torch.linalg.lstsq(module.lora_A_k_list[task_idx].weight.data.T.cpu() * math.sqrt(3), true_U) + approx2_U = module.lora_A_k_list[task_idx].weight.data.T.cpu() * math.sqrt(3) @ alphas.solution + + if APPROX_FEAT: + self.feature_list.append(approx2_U) + self.feature_list_each_tasks[task_idx][i] = approx2_U + else: + self.feature_list.append(true_U) + self.feature_list_each_tasks[task_idx][i] = true_U + + self.project_type.append('remove') + + else: + for i, module in enumerate(self.attention_modules): + + activation = module.cur_matrixs[self.probe_indices_svd].sum(dim=0) / 512 + act_hat = activation - torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) @ activation + + U, _, _ = torch.linalg.svd(act_hat, full_matrices = False) + true_U = U[:, :module.lora_rank] + + alphas = torch.linalg.lstsq(module.lora_A_k_list[task_idx].weight.data.T.cpu() * math.sqrt(3), true_U) + approx2_U = module.lora_A_k_list[task_idx].weight.data.T.cpu() * math.sqrt(3) @ alphas.solution + + if APPROX_FEAT: + self.feature_list[i] = torch.cat([self.feature_list[i], approx2_U], dim=1) + self.feature_list_each_tasks[task_idx][i] = approx2_U + else: + self.feature_list[i] = torch.cat([self.feature_list[i], true_U], dim=1) + self.feature_list_each_tasks[task_idx][i] = true_U + + print('-'*40) + print(f'Threshold: {threshold}') + print('-'*40) + for i in range(len(self.feature_list)): + ''' + if self.project_type[i]=='remove' and (self.feature_list[i].shape[1] > (self.feature_list[i].shape[0]/2)): + feature = self.feature_list[i] + U, S, V = np.linalg.svd(feature) + new_feature = U[:,feature.shape[1]:] + self.feature_list[i] = new_feature + self.project_type[i] = 'retain' + elif self.project_type[i]=='retain': + assert self.feature_list[i].shape[1] <= (self.feature_list[i].shape[0]/2) + ''' + print ('Layer {} : {}/{} type {}'.format(i+1,self.feature_list[i].shape[1], self.feature_list[i].shape[0], self.project_type[i])) + print('-'*40) + + @torch.no_grad() + def _update_input_matrix(self, train_loader): + + for module in self.attention_modules: + module.reset_input_matrix() + + _set_random(os.environ["PYTHONHASHSEED"]) # consistency + self.dataset = [] + for batch in tqdm(train_loader, desc="Forwarding to get input matrix", disable=(self.local_rank != 0)): + self._network.update_input_matrix(batch['image'].to(self.device)) + self.dataset.append(batch['image']) + + self.dataset = torch.cat(self.dataset, dim=0) + + for module in self.attention_modules: + module.cur_matrixs = torch.cat(module.cur_matrixs, dim=0) + module.cur_matrixs = torch.bmm( + module.cur_matrixs.permute(0, 2, 1), + module.cur_matrixs + ).cpu() + + def get_parameters(self, config): + return self._network.parameters() \ No newline at end of file diff --git a/core/model/MInfLoRA2.py b/core/model/MInfLoRA2.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d7fbc61baee020f2df127d45f778fdd19adc43 --- /dev/null +++ b/core/model/MInfLoRA2.py @@ -0,0 +1,390 @@ +""" +Code Reference: +https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py +""" +import os +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from torch import optim +from torch.nn.parameter import Parameter +from tqdm import tqdm +from math import pi +from torchvision import transforms + +from .backbone.transformer import MultiHeadAttention_MultiMaskedLoRA + +Epsilon = 0.5 + +class TopK: + + ''' + A class to maintain a collection of the top K items based on a specified attribute. + + This class allows for the dynamic addition of items, each represented as a dictionary, + where each dictionary must have a key 'proj_norm' that represents the value used + to determine the ranking. The class keeps track of the top K items with the highest + 'proj_norm' values. + ''' + + def __init__(self, k): + self.k = k + self.top_k_list = [] + + def add(self, dict): + if len(self.top_k_list) < self.k: + self.top_k_list.append(dict) + elif dict['proj_norm'] > min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: + self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['proj_norm'])) + self.top_k_list.append(dict) + elif dict['proj_norm'] == min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm'] and \ + dict['proj_norm'] == max(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: + self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['task_id'])) + self.top_k_list.append(dict) + + def get_top_k(self): + return self.top_k_list + +class SiNet(nn.Module): + def __init__(self, backbone, **kwargs): + super().__init__() + + self._cur_task_id = -1 + self.backbone = backbone + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + + self.classifier_pool = nn.ModuleList([ + nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] + + [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)]) + + for name, module in self.backbone.named_modules(): + if 'transformer' in name and 'blocks' not in name: + self.transformer_module = module + + def update_fc(self): + self._cur_task_id += 1 + + def fc_only(self, x, expert_id): + logits = [] + for prompts in self.classifier_pool[:expert_id + 1]: + logits.append(prompts(x)) + return torch.cat(logits, dim=1) + + def fc_only2(self, x): + logits = [] + for prompts in self.classifier_pool[:self._cur_task_id + 1]: + logits.append(prompts(x)) + return torch.cat(logits, dim=1) + + def get_feature(self, x, expert_id): + features = self.backbone(x, expert_id = expert_id) + return features + + def forward(self, x, expert_id, inference = False): + logits = [] + features = self.backbone(x, expert_id = expert_id) + + if inference: + + probs = self.transformer_module.probs + probs = torch.Tensor(probs[-1]).to(x.device) # consider only last layer + + # Bayesian + for i, prompts in enumerate(self.classifier_pool[:self._cur_task_id + 1]): + logits.append(prompts(features)) + + logits = torch.cat(logits, dim=1) + + return logits + + else: + logits.append(self.classifier_pool[self._cur_task_id](features)) + return torch.cat(logits, dim=1) + + def update_input_matrix(self, x): + self.backbone(x, expert_id = 0, get_input_matrix = True) + +class MInfLoRA2(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.device = device + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.task_num = kwargs["task_num"] + self.lame = kwargs["lame"] + self.lamb = kwargs["lamb"] + self.eval_mat = kwargs['eval_mat'] + + self._known_classes = 0 + self.feature_list = [] + self.project_type = [] + + self._network = SiNet(backbone, **kwargs) + + self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_MultiMaskedLoRA)] + + # TRGP Implementation + self.feature_list_each_tasks = [[np.zeros((1)) for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + self.final_decision = [[np.zeros((1)) for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + self.before_mat = [[0 for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + + self.experts_distributions = [] + + # Class Alignment Implementation + self._use_class_alignment = kwargs['use_ca'] + self._class_means = None + self._class_covs = None + self._dataset = kwargs['dataset'] + if self._dataset == 'cifar': + self.logit_norm = None + else: + self.logit_norm = 0.1 + + self.lll = [] + + self._network.to(self.device) + + def observe(self, data): + ''' + Called during the training phase, it inputs a batch of training examples and returns the prediction, accuracy, and forward loss. + ''' + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + logits = self._network(x, expert_id = self._network._cur_task_id) # hardcoded for task_id + loss = F.cross_entropy(logits, y) + + preds = logits.max(1)[1] + acc = preds.eq(y).sum().item() / y.shape[0] + + return preds, acc, loss + + def inference(self, data, **kwargs): + + task_id = kwargs['task_id'] if 'task_id' in kwargs else None + x, y = data['image'].to(self.device), data['label'].to(self.device) + + logits = self._network(x, expert_id = 0, inference = True) + preds = logits.max(1)[1] + acc = preds.eq(y).sum().item() / y.shape[0] + + return preds, acc + + @torch.no_grad() + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + if task_idx == 1: + self._known_classes += self.init_cls_num + elif task_idx > 1: + self._known_classes += self.inc_cls_num + self._network.update_fc() + + for module in self.attention_modules: + module.init_param() + + self._update_input_matrix(train_loader, test_loaders[0].dataset.trfms) + + for i, module in enumerate(self.attention_modules): + + topk = TopK(1) + + mat = module.cur_matrix.cpu().numpy() + mat_norm = np.linalg.norm(mat) + + for task_id in range(task_idx): + + proj_norm = np.linalg.norm(self.feature_list_each_tasks[task_id][i] @ self.feature_list_each_tasks[task_id][i].T @ mat) + + if proj_norm > Epsilon * mat_norm: + topk.add({'proj_norm':proj_norm, 'task_id': task_id}) + + self.final_decision[task_idx][i] = [dic['task_id'] for dic in topk.get_top_k()] + + module.enable_scale(task_id = task_idx, space = [torch.tensor(self.feature_list_each_tasks[task_id][i]).to(self.device) for task_id in self.final_decision[task_idx][i]]) + print(f'Layer {i} of {task_idx} consider {self.final_decision[task_idx][i]} as trust region') + + if task_idx == 0: + for i, module in enumerate(self.attention_modules): + U, _, _ = torch.linalg.svd(module.cur_matrix) + module.lora_A_k.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.lora_A_v.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.reset_input_matrix() + else: + + for i, module in enumerate(self.attention_modules): + assert self.project_type[i] == 'remove' or self.project_type[i] == 'retain' + + cur_matrix = module.cur_matrix + feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) + + if self.project_type[i] == 'remove': + cur_matrix = cur_matrix - feature_mat @ cur_matrix + else: + cur_matrix = feature_mat @ cur_matrix + + U, _, _ = np.linalg.svd(cur_matrix.cpu().numpy(), full_matrices = False) + U = torch.tensor(U).to(self.device) + + module.lora_A_k.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.lora_A_v.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.reset_input_matrix() + + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + if f"classifier_pool.{task_idx}" in name or f"lora_B" in name or f"scale_param.{task_idx}" in name: + param.requires_grad_(True) + unfrezeed_params = [name for name, param in self._network.named_parameters() if param.requires_grad] + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + Called after each task before final testing, it is used to perform preliminary operations on the mapping matrix to facilitate the update of lora_a layer in the next round of before_task + ''' + + [module.merge_weight() for module in self.attention_modules] + + self._update_feature(task_idx, train_loader, test_loaders) + + self._update_input_matrix(train_loader, test_loaders[0].dataset.trfms) + + threshold = self.lamb + + for i, module in enumerate(self.attention_modules): + + activation = module.cur_matrix + U, S, _ = np.linalg.svd(activation, full_matrices=False) + sval_ratio = (S**2)/(S**2).sum() + + r = max(np.sum(np.cumsum(sval_ratio) < threshold), 1) + + # DEBUG, REMOVE + tnsr = torch.Tensor(U[:, :r]) + module.save_space(task_idx, tnsr) + + target_r = max([r] + [module.saved_space[ttt][0].shape[1] for ttt in range(task_idx)]) + + for ttt in range(task_idx + 1): + # 对齐 + saved = module.saved_space[ttt][0] + + if saved.shape[1] < target_r: + new = torch.zeros((768, target_r)) + new[:, :saved.shape[1]] = saved + module.saved_space[ttt][0] = new + + module.reset_input_matrix() + + @torch.no_grad() + def _update_feature(self, task_idx, train_loader, test_loaders): + ''' + Update feature lists and the corresponding type + ''' + + self._update_input_matrix(train_loader, test_loaders[0].dataset.trfms) + + threshold = (self.lame - self.lamb)*task_idx/self.task_num + self.lamb + + if task_idx == 0: + for i, attention_module in enumerate(self.attention_modules): + activation = attention_module.cur_matrix + + U, S, _ = np.linalg.svd(activation, full_matrices=False) + sval_ratio = (S**2)/(S**2).sum() + r = max(np.sum(np.cumsum(sval_ratio) < threshold), 1) + assert r < activation.shape[0]/2 + + self.feature_list_each_tasks[task_idx][i] = U[:, :r] + self.feature_list.append(U[:, :r]) + self.project_type.append('remove') + + attention_module.reset_input_matrix() + else: + for i, attention_module in enumerate(self.attention_modules): + + activation = attention_module.cur_matrix + _, S, _ = np.linalg.svd(activation, full_matrices=False) + sval_total = (S**2).sum() + + if self.project_type[i] == 'remove': + + act_hat = activation - torch.Tensor(self.feature_list[i] @ self.feature_list[i].transpose()) @ activation + U, S, _ = np.linalg.svd(act_hat, full_matrices = False) + sigma = S**2 + + delta = (torch.tensor(self.feature_list[i]).T @ activation @ activation.T @ torch.tensor(self.feature_list[i])).diagonal() + + stack = np.hstack((delta, sigma)) + stack_index = np.argsort(stack)[::-1] # the index of each element in descending sorted array + stack = np.sort(stack)[::-1] # descending sorted array + + if threshold * sval_total <= 0: + r = 0 + else: + r = min(np.sum(np.cumsum(stack) < threshold * sval_total) + 1, activation.shape[0]) + + Ui = np.hstack((self.feature_list[i], U)) + sel_each = stack_index[:r] + sel_overall = sel_each[sel_each >= len(delta)] # without overlap + + self.feature_list[i] = np.hstack((self.feature_list[i], Ui[:, sel_overall])) + self.feature_list_each_tasks[task_idx][i] = Ui[:, sel_each] + + if sel_overall.shape[0] == 0: + print(f'Skip Updating Space for layer: {i+1}') + + else: + act_hat = Torch.Tensor(self.feature_list[i] @ self.feature_list[i].transpose()) @ activation + U,S,_ = np.linalg.svd(act_hat, full_matrices = False) + sval_hat = (S**2).sum() + sval_ratio = (S**2)/sval_total + accumulated_sval = sval_hat/sval_total + + if accumulated_sval < 1 - threshold: + print (f'Skip Updating Space for layer: {i+1}') + else: + r = np.sum(accumulated_sval - np.cumsum(sval_ratio) >= 1 - threshold) + 1 + act_feature = self.feature_list[i] - U[:,0:r] @ U[:,0:r].T @ self.feature_list[i] + U, _, _ = np.linalg.svd(act_feature) + self.feature_list[i]=U[:,:self.feature_list[i].shape[1]-r] + + attention_module.reset_input_matrix() + + print('-'*40) + print(f'Threshold: {threshold}') + print('-'*40) + for i in range(len(self.feature_list)): + ''' + if self.project_type[i]=='remove' and (self.feature_list[i].shape[1] > (self.feature_list[i].shape[0]/2)): + feature = self.feature_list[i] + U, S, V = np.linalg.svd(feature) + new_feature = U[:,feature.shape[1]:] + self.feature_list[i] = new_feature + self.project_type[i] = 'retain' + elif self.project_type[i]=='retain': + assert self.feature_list[i].shape[1] <= (self.feature_list[i].shape[0]/2) + ''' + print ('Layer {} : {}/{} type {}'.format(i+1,self.feature_list[i].shape[1], self.feature_list[i].shape[0], self.project_type[i])) + print('-'*40) + + @torch.no_grad() + def _update_input_matrix(self, train_loader, test_trfms): + + if self.eval_mat: + self._network.eval() + train_trfms = train_loader.dataset.trfms + train_loader.dataset.trfms = test_trfms + + for batch in tqdm(train_loader, desc = "Forwarding to get input matrix"): + self._network.update_input_matrix(batch['image'].to(self.device)) + + if self.eval_mat: + self._network.train() + train_loader.dataset.trfms = train_trfms + + def get_parameters(self, config): + return self._network.parameters() \ No newline at end of file diff --git a/core/model/MInfLoRA3.py b/core/model/MInfLoRA3.py new file mode 100644 index 0000000000000000000000000000000000000000..98e9048bdbc7142ab278285c7ce1419f2eeb4b2b --- /dev/null +++ b/core/model/MInfLoRA3.py @@ -0,0 +1,393 @@ +""" +Code Reference: +https://github.com/liangyanshuo/InfLoRA/blob/main/methods/inflora.py +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np + +from tqdm import tqdm +from .backbone.transformer import MultiHeadAttention_MultiMaskedLoRA3 + +Epsilon = 0.5 + +class TopK: + + ''' + A class to maintain a collection of the top K items based on a specified attribute. + + This class allows for the dynamic addition of items, each represented as a dictionary, + where each dictionary must have a key 'proj_norm' that represents the value used + to determine the ranking. The class keeps track of the top K items with the highest + 'proj_norm' values. + ''' + + def __init__(self, k): + self.k = k + self.top_k_list = [] + + def add(self, dict): + if len(self.top_k_list) < self.k: + self.top_k_list.append(dict) + elif dict['proj_norm'] > min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: + self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['proj_norm'])) + self.top_k_list.append(dict) + elif dict['proj_norm'] == min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm'] and \ + dict['proj_norm'] == max(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: + self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['task_id'])) + self.top_k_list.append(dict) + + def get_top_k(self): + return self.top_k_list + +class SiNet(nn.Module): + def __init__(self, backbone, **kwargs): + super().__init__() + + self._cur_task_id = -1 + self.backbone = backbone + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + + self.classifier_pool = nn.ModuleList([ + nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] + + [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)]) + + for name, module in self.backbone.named_modules(): + if 'transformer' in name and 'blocks' not in name: + self.transformer_module = module + + def update_fc(self): + self._cur_task_id += 1 + + def forward(self, x, expert_id, inference = False): + logits = [] + features = self.backbone(x, expert_id = expert_id) + + if inference: + + # Bayesian + for i, prompts in enumerate(self.classifier_pool[:self._cur_task_id + 1]): + # No Masking + logits.append(prompts(features)) + + logits = torch.cat(logits, dim=1) + + return logits + + else: + logits.append(self.classifier_pool[self._cur_task_id](features)) + return torch.cat(logits, dim=1) + + def update_input_matrix(self, x): + self.backbone(x, expert_id = -1, get_input_matrix = True) + +class MInfLoRA3(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.device = device + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.task_num = kwargs["task_num"] + self.lame = kwargs["lame"] + self.lamb = kwargs["lamb"] + self.embd_dim = kwargs["embd_dim"] + self.eval_mat = kwargs['eval_mat'] + + self._known_classes = 0 + self.feature_list = [] + self.project_type = [] + + self.distributed = torch.distributed.is_initialized() + self.local_rank = torch.distributed.get_rank() if self.distributed else 0 + self._network = SiNet(backbone, **kwargs) + + self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_MultiMaskedLoRA3)] + + # TRGP Implementation + self.feature_list_each_tasks = [[np.zeros((1)) for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + self.final_decision = [[np.zeros((1)) for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + self.before_mat = [[0 for _ in range(len(self.attention_modules))] for _ in range(self.task_num)] + + self.experts_distributions = [] + + # Class Alignment Implementation + self._use_class_alignment = kwargs['use_ca'] + self._class_means = None + self._class_covs = None + self._dataset = kwargs['dataset'] + if self._dataset == 'cifar': + self.logit_norm = None + else: + self.logit_norm = 0.1 + + self.lll = [] + + self._network.to(self.device) + + def observe(self, data): + + x, y = data['image'].to(self.device, non_blocking=True), data['label'].to(self.device, non_blocking=True) - self._known_classes + + logits = self._network(x, expert_id = self._network._cur_task_id) + loss = F.cross_entropy(logits, y) + + preds = logits.max(1)[1] + acc = preds.eq(y).sum().item() / y.shape[0] + + return preds, acc, loss + + def inference(self, data, **kwargs): + + task_id = kwargs['task_id'] if 'task_id' in kwargs else -1 + x, y = data['image'].to(self.device, non_blocking=True), data['label'].to(self.device, non_blocking=True) + + logits = self._network(x, expert_id = task_id, inference = True) + preds = logits.max(1)[1] + acc = preds.eq(y).sum().item() / y.shape[0] + + return preds, acc + + @torch.no_grad() + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self._network.update_fc() + + [module.init_param() for module in self.attention_modules] + + self._update_input_matrix(train_loader, test_loaders[0].dataset.trfms) + + ''' + for i, module in enumerate(self.attention_modules): + + topk = TopK(1) + + mat = module.cur_matrix.cpu().numpy() + mat_norm = np.linalg.norm(mat) + + for task_id in range(task_idx): + + proj_norm = np.linalg.norm(self.feature_list_each_tasks[task_id][i] @ self.feature_list_each_tasks[task_id][i].T @ mat) + + if proj_norm > Epsilon * mat_norm: + topk.add({'proj_norm':proj_norm, 'task_id': task_id}) + + self.final_decision[task_idx][i] = [dic['task_id'] for dic in topk.get_top_k()] + print(f'Layer {i} of {task_idx} consider {self.final_decision[task_idx][i]} as trust region') + ''' + + if self.local_rank == 0: + + if task_idx == 0: + for i, module in enumerate(self.attention_modules): + + U, _, _ = torch.linalg.svd(module.cur_matrix) + U = torch.Tensor(U).to(self.device) + + module.lora_A_k.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + module.lora_A_v.weight.data.copy_(U[:,:module.lora_rank].T/math.sqrt(3)) + else: + for i, module in enumerate(self.attention_modules): + assert self.project_type[i] == 'remove' or self.project_type[i] == 'retain' + + #tr = self.final_decision[task_idx][i][0] + #feature_mat = torch.Tensor(self.feature_list_each_tasks[tr][i] @ self.feature_list_each_tasks[tr][i].T).to(self.device) + + #target_shape = max(70, self.feature_list[i].shape[1]) # constant 50 and whole feature_list and no QQ^T, get best result for now + target_shape = 768 + + # either /math.sqrt(3) or no /math.sqrt(3) is bad + + cur_matrix = module.cur_matrix.to(self.device) + feature_mat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T).to(self.device) + + q_weight, k_weight, v_weight = module.qkv.weight.chunk(3, dim=0) + kk = feature_mat - k_weight.data @ feature_mat + vv = feature_mat - v_weight.data @ feature_mat + + U, _, _ = np.linalg.svd(kk.cpu().numpy(), full_matrices = False) + U = torch.Tensor(U).to(self.device) + module.space_k[task_idx] = U[:, :target_shape].T/math.sqrt(3) + + U, _, _ = np.linalg.svd(vv.cpu().numpy(), full_matrices = False) + U = torch.Tensor(U).to(self.device) + module.space_v[task_idx] = U[:, :target_shape].T/math.sqrt(3) + + if self.project_type[i] == 'remove': + cur_matrix = cur_matrix - feature_mat @ cur_matrix + else: + cur_matrix = feature_mat @ cur_matrix + + U, _, _ = np.linalg.svd(cur_matrix.cpu().numpy(), full_matrices = False) + U = torch.Tensor(U).to(self.device) + + module.lora_A_k.weight.data.copy_(U[:, :module.lora_rank].T/math.sqrt(3)) + module.lora_A_v.weight.data.copy_(U[:, :module.lora_rank].T/math.sqrt(3)) + + # Initilize space_k and space_v before sync + if self.local_rank != 0 and task_idx != 0: + for module in self.attention_modules: + module.space_k[task_idx] = torch.empty((50, self.embd_dim)).to(self.device) + module.space_v[task_idx] = torch.empty((50, self.embd_dim)).to(self.device) + + if self.distributed and task_idx != 0: + dist.barrier() + for module in self.attention_modules: + dist.broadcast(module.lora_A_k.weight.data, 0) + dist.broadcast(module.lora_A_v.weight.data, 0) + dist.broadcast(module.space_k[task_idx].contiguous(), 0) + dist.broadcast(module.space_v[task_idx].contiguous(), 0) + + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + if f"classifier_pool.{task_idx}" in name or \ + f"lora_B_k_list.{task_idx}" in name or \ + f"lora_B_v_list.{task_idx}" in name or \ + f"scale_param.{task_idx}" in name: + param.requires_grad_(True) + + if self.local_rank == 0: + for name, param in self._network.named_parameters(): + if param.requires_grad: + print(name) + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + Called after each task before final testing, it is used to perform preliminary operations on the mapping matrix to facilitate the update of lora_a layer in the next round of before_task + ''' + + self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + + [module.merge_weight() for module in self.attention_modules] + + self._update_feature(task_idx, train_loader, test_loaders) + + @torch.no_grad() + def _update_feature(self, task_idx, train_loader, test_loaders): + ''' + Update feature lists and the corresponding type + ''' + + self._update_input_matrix(train_loader, test_loaders[0].dataset.trfms) + + if self.local_rank == 0: + + threshold = (self.lame - self.lamb)*task_idx/self.task_num + self.lamb + + if task_idx == 0: + for i, module in enumerate(self.attention_modules): + + activation = module.cur_matrix + + U, S, _ = np.linalg.svd(activation, full_matrices=False) + sval_ratio = (S**2)/(S**2).sum() + r = max(np.sum(np.cumsum(sval_ratio) < threshold), 1) + assert r < activation.shape[0]/2 + + self.feature_list_each_tasks[task_idx][i] = U[:, :r] + self.feature_list.append(U[:, :r]) + self.project_type.append('remove') + + else: + for i, module in enumerate(self.attention_modules): + + activation = module.cur_matrix + _, S, _ = np.linalg.svd(activation, full_matrices=False) + sval_total = (S**2).sum() + + if self.project_type[i] == 'remove': + + act_hat = activation - torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) @ activation + U, S, _ = np.linalg.svd(act_hat, full_matrices = False) + sigma = S**2 + + delta = (torch.Tensor(self.feature_list[i]).T @ activation @ activation.T @ torch.Tensor(self.feature_list[i])).diagonal() + + stack = np.hstack((delta, sigma)) + stack_index = np.argsort(stack)[::-1] # the index of each element in descending sorted array + stack = np.sort(stack)[::-1] # descending sorted array + + if threshold * sval_total <= 0: + r = 0 + else: + r = min(np.sum(np.cumsum(stack) < threshold * sval_total) + 1, activation.shape[0]) + + Ui = np.hstack((self.feature_list[i], U)) + sel_each = stack_index[:r] + sel_overall = sel_each[sel_each >= len(delta)] # without overlap + + self.feature_list[i] = np.hstack((self.feature_list[i], Ui[:, sel_overall])) + self.feature_list_each_tasks[task_idx][i] = Ui[:, sel_each] + + if sel_overall.shape[0] == 0: + print(f'Skip Updating Space for layer: {i+1}') + + else: + act_hat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) @ activation + U,S,_ = np.linalg.svd(act_hat, full_matrices = False) + sval_hat = (S**2).sum() + sval_ratio = (S**2)/sval_total + accumulated_sval = sval_hat/sval_total + + if accumulated_sval < 1 - threshold: + print (f'Skip Updating Space for layer: {i+1}') + else: + r = np.sum(accumulated_sval - np.cumsum(sval_ratio) >= 1 - threshold) + 1 + act_feature = self.feature_list[i] - U[:,0:r] @ U[:,0:r].T @ self.feature_list[i] + U, _, _ = np.linalg.svd(act_feature) + self.feature_list[i]=U[:,:self.feature_list[i].shape[1]-r] + + print('-'*40) + print(f'Threshold: {threshold}') + print('-'*40) + for i in range(len(self.feature_list)): + ''' + if self.project_type[i]=='remove' and (self.feature_list[i].shape[1] > (self.feature_list[i].shape[0]/2)): + feature = self.feature_list[i] + U, S, V = np.linalg.svd(feature) + new_feature = U[:,feature.shape[1]:] + self.feature_list[i] = new_feature + self.project_type[i] = 'retain' + elif self.project_type[i]=='retain': + assert self.feature_list[i].shape[1] <= (self.feature_list[i].shape[0]/2) + ''' + print ('Layer {} : {}/{} type {}'.format(i+1,self.feature_list[i].shape[1], self.feature_list[i].shape[0], self.project_type[i])) + print('-'*40) + + @torch.no_grad() + def _update_input_matrix(self, train_loader, test_trfms): + + if self.eval_mat: + self._network.eval() + train_trfms = train_loader.dataset.trfms + train_loader.dataset.trfms = test_trfms + + for module in self.attention_modules: + module.reset_input_matrix() + + for batch in tqdm(train_loader, desc="Forwarding to get input matrix", disable=(self.local_rank != 0)): + self._network.update_input_matrix(batch['image'].to(self.device, non_blocking=True)) + + if self.distributed: # Combine input matrix across all GPUs + for module in self.attention_modules: + n_cur_matrix = torch.tensor(module.n_cur_matrix).to(self.device) + cur_matrix = (module.cur_matrix * module.n_cur_matrix).to(self.device) + + dist.all_reduce(cur_matrix, op=dist.ReduceOp.SUM) + dist.all_reduce(n_cur_matrix, op=dist.ReduceOp.SUM) + + module.n_cur_matrix = n_cur_matrix.item() + module.cur_matrix = cur_matrix.cpu() / module.n_cur_matrix + + if self.eval_mat: + self._network.train() + train_loader.dataset.trfms = train_trfms + + def get_parameters(self, config): + return self._network.parameters() \ No newline at end of file diff --git a/core/model/__init__.py b/core/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca88a3bbfb06a6c6968f0d7f1a0aaed103b0fb9b --- /dev/null +++ b/core/model/__init__.py @@ -0,0 +1,34 @@ +from core.model.backbone import * +from core.model.buffer import * + +from .finetune import Finetune +from .icarl import ICarl +from .lucir import LUCIR +from .lwf import LWF +from .wa import WA +from .bic import bic +from .ewc import EWC +from .ocm import OCM +from .eraml import ERAML +from .erace import ERACE +from .der import DER +from .dualprompt import DualPrompt +from .l2p import L2P +from .codaprompt import CodaPrompt +from .praka import PRAKA +from .ranpac import RanPAC +from .trgp import TRGP +from .InfLoRA import InfLoRA +from .InfLoRA_opt import InfLoRA_OPT +from .MInfLoRA import MInfLoRA +from .MInfLoRA2 import MInfLoRA2 +from .MInfLoRA3 import MInfLoRA3 +from .moe_adapter4cl import MOE_ADAPTER4CL +from .dmnsp import DMNSP +from .rapf import RAPF +from .gpm import GPM +from .api import API +from .dap import DAP +from .sd_lora import SD_LoRA +from .lora_sub import LoRAsub_DRS +from .cl_lora import CL_LoRA \ No newline at end of file diff --git a/core/model/api.py b/core/model/api.py new file mode 100644 index 0000000000000000000000000000000000000000..6839905139d269b6eec4a7c9d65be889e078b48f --- /dev/null +++ b/core/model/api.py @@ -0,0 +1,339 @@ +""" +@inproceedings{liang2023adaptive, + title={Adaptive Plasticity Improvement for Continual Learning}, + author={Liang, Yan-Shuo and Li, Wu-Jun}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={7816--7825}, + year={2023} +} + +Code Reference: +https://github.com/liangyanshuo/Adaptive-Plasticity-Improvement-for-Continual-Learning +""" + +import math +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import numpy as np + +from .backbone.alexnet import Conv2d_API, Linear_API, AlexNet_API + +batch_list = [2*12, 100, 100] +ksize = [4, 3, 2, 1, 1] # kernel size of each conv layer +channels = [3, 64, 128, 1024, 2048] +conv_output_size = [29, 12, 5] # output size of each conv layer + +class Network(nn.Module): + + def __init__(self, backbone, **kwargs): + + super().__init__() + self.backbone = backbone + + self.classifiers = nn.ModuleList([ + nn.Linear(backbone.feat_dim, kwargs['init_cls_num'], bias = False)] + + [nn.Linear(backbone.feat_dim, kwargs['inc_cls_num'], bias = False) for _ in range(kwargs['task_num'] - 1)] + ) + + def forward(self, data, t, compute_input_matrix = False): + + feat = self.backbone(data, t, compute_input_matrix) + return [fc(feat) for fc in self.classifiers] + +class API(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + self.network = Network(backbone, **kwargs) + self.device = device + + self.task_num = kwargs["task_num"] + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self._known_classes = 0 + + self.feature_list = [] + self.feature_mat = [] + self.project_type = [] + self.step = 0.5 + self.K = 10 + + self.layers = [module for module in self.network.modules() if isinstance(module, Conv2d_API) or isinstance(module, Linear_API)] + + self.network.to(self.device) + + def observe(self, data, stage=0): + + # Stage=0 : The main train + # Stage=1 : The FIRst train + # Stage=2 : The Second train + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + if stage == 1 or stage == 2: # evaluate should only in stage==2 + logits = self.network(x, self.cur_task - 1) + else: + logits = self.network(x, self.cur_task) + + loss = F.cross_entropy(logits[self.cur_task], y) + + preds = logits[self.cur_task].max(1)[1] + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + loss.backward() + + per_layer_norm = [layer.weight.grad.norm(p=2) for layer in self.layers] + + if self.cur_task > 0: + for i, layer in enumerate(self.layers): + sz = layer.weight.grad.data.size(0) + expand = self.expand[i][-1] + assert expand == self.expand[i][self.cur_task-1] + if self.project_type[i] == 'retain': + layer.weight.grad.data[:, :expand] = (layer.weight.grad.data[:,:expand].view(sz, -1) @ self.feature_mat[i]).view(layer.weight[:, :expand].size()) + elif self.project_type[i] == 'remove': + layer.weight.grad.data[:, :expand] = (layer.weight.grad.data[:,:expand].view(sz, -1) - + layer.weight.grad.data[:,:expand].view(sz, -1) @ self.feature_mat[i]).view(layer.weight[:, :expand].size()) + + for i, layer in enumerate(self.layers): + self.per_layer_retain[i] += layer.weight.grad.norm(p=2)/per_layer_norm[i] + + if stage == 1: + self.optimizer_stage1.step() + else: + # either stage 0 or stage 2, stage 0 call optimizer.step() and stage 2 do nothing + return preds, acc, loss + + def inference(self, data, task_id=-1): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + # Task-Aware (Task-Incremetanl Scenario) + if task_id > -1: + + if task_id == 0: + bias_classes = 0 + elif task_id == 1: + bias_classes = self.init_cls_num + else: + bias_classes = self.init_cls_num + (task_id - 1) * self.inc_cls_num + + logits = self.network(x, task_id) + preds = logits[task_id].max(1)[1] + bias_classes + + # Task-Agnostic (Class-Incremetanl Scenario) + else: + + logits = torch.cat(self.network(x, self.cur_task), dim=-1) + preds = logits.max(1)[1] + + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self.per_layer_retain = [0., 0., 0., 0., 0.] # depends on backbone, if resnet then differerent + self.cur_task = task_idx + + if task_idx == 1: + self._known_classes += self.init_cls_num + elif task_idx > 1: + self._known_classes += self.inc_cls_num + + if task_idx > 0: + + # bn's parameters are only learned for the first task + for name, param in self.network.named_parameters(): + param.requires_grad_(True) + if 'bn' in name: + param.requires_grad_(False) + + for ep in range(5): + for batch in train_loader: + self.optimizer_stage1.zero_grad() + self.observe(batch, stage = 1) + + # TODO: early stop + + for batch in train_loader: + self.observe(batch, stage = 2) + + num_iter = len(train_loader) * (5 + 1) + self.per_layer_retain = [(retain/num_iter).item() for retain in self.per_layer_retain] + + mat_list = self.get_mat(task_idx - 1, train_loader) + + for i, mat in enumerate(mat_list): + sz = mat.shape[-1] + mat_list[i] = np.linalg.norm( + mat[:channels[i] * ksize[i] * ksize[i]].T.reshape(sz, channels[i], ksize[i], ksize[i]), ord=2, axis=(2,3) + ).T + + sizes, ws = [], [] + for i, layer in enumerate(self.layers): + + U, _, _ = np.linalg.svd(mat_list[i], full_matrices=False) + + expand_dim = max((self.step - self.per_layer_retain[i]) * self.K, 0) + size = max(min(math.ceil(expand_dim), channels[i]), 0) + + sizes.append(size) + ws.append(torch.Tensor(U[:, :size]).to(self.device)) + + self.network.backbone.expand(sizes, ws) + self.network.to(self.device) + + self.layers = [module for module in self.network.modules() if isinstance(module, Conv2d_API) or isinstance(module, Linear_API)] + + # not include the additional w + self.optimizer_stage1 = optim.SGD(self.get_parameters(additional=False), lr=0.01) + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + mat_list = self.get_mat(task_idx, train_loader) + + self.expand = [] # self.expand[i][j] is the expanded size of i-th layer in j-th task + for i, layer in enumerate(self.layers): + self.expand.append(np.cumsum([0] + layer.expand)) + self.expand[i] += channels[i] + + for i, (feature, layer) in enumerate(zip(self.feature_list, self.layers)): + assert task_idx > 0 + if isinstance(layer, Conv2d_API): + sz = layer.expand[task_idx - 1] * ksize[i] * ksize[i] + elif isinstance(layer, Linear_API): + sz = layer.expand[task_idx - 1] + else: + raise NotImplementedError + + if sz: + if self.project_type[i] == 'retain': + self.feature_list[i] = np.vstack((self.feature_list[i],np.zeros((sz, self.feature_list[i].shape[1])))) + self.feature_list[i] = np.hstack((self.feature_list[i],np.zeros((self.feature_list[i].shape[0], sz)))) + self.feature_list[i][-sz:,-sz:] = np.eye(sz) + elif self.project_type[i] == 'remove': + self.feature_list[i] = np.vstack((self.feature_list[i],np.zeros((sz,self.feature_list[i].shape[1])))) + else: + raise Exception('Wrong project type') + + threshold = 0.97 + task_idx * 0.03 / self.task_num + + # get the space for each layer + if task_idx == 0: + for i, activation in enumerate(mat_list): + + U, S, _ = np.linalg.svd(activation, full_matrices = False) + # criteria (Eq-5) + sval_total = (S**2).sum() + sval_ratio = (S**2)/sval_total + r = np.sum(np.cumsum(sval_ratio) < threshold) + + if r < activation.shape[0]/2: + self.feature_list.append(U[:, :r]) + self.project_type.append('remove') + else: + self.feature_list.append(U[:, r:]) + self.project_type.append('retain') + + else: + for i, activation in enumerate(mat_list): + + _, S, _ = np.linalg.svd(activation, full_matrices=False) + sval_total = (S**2).sum() + + if self.project_type[i] == 'remove': + + act_hat = activation - self.feature_list[i] @ self.feature_list[i].T @ activation + U, S, _ = np.linalg.svd(act_hat, full_matrices = False) + sval_hat = (S**2).sum() + sval_ratio = (S**2)/sval_total + accumulated_sval = (sval_total-sval_hat)/sval_total + + if accumulated_sval >= threshold: + print (f'Skip Updating DualGPM for layer: {i+1}') + else: + r = np.sum(np.cumsum(sval_ratio) + accumulated_sval < threshold) + 1 + Ui = np.hstack((self.feature_list[i], U[:, :r])) + self.feature_list[i] = Ui[:, :min(Ui.shape[0], Ui.shape[1])] + + else: + act_hat = torch.Tensor(self.feature_list[i] @ self.feature_list[i].T) @ activation + U,S,_ = np.linalg.svd(act_hat, full_matrices = False) + sval_hat = (S**2).sum() + sval_ratio = (S**2)/sval_total + accumulated_sval = sval_hat/sval_total + + if accumulated_sval < 1 - threshold: + print (f'Skip Updating Space for layer: {i+1}') + else: + r = np.sum(accumulated_sval - np.cumsum(sval_ratio) >= 1 - threshold) + 1 + act_feature = self.feature_list[i] - U[:, :r] @ U[:, :r].T @ self.feature_list[i] + U, _, _ = np.linalg.svd(act_feature) + self.feature_list[i]=U[:,:self.feature_list[i].shape[1]-r] + + print('-'*40) + print('Gradient Constraints Summary') + print('-'*40) + for i in range(len(self.feature_list)): + if self.project_type[i]=='remove' and (self.feature_list[i].shape[1] > (self.feature_list[i].shape[0]/2)): + feature = self.feature_list[i] + U, _, _ = np.linalg.svd(feature) + new_feature = U[:,feature.shape[1]:] + self.feature_list[i] = new_feature + self.project_type[i] = 'retain' + print ('Layer {} : {}/{} type {}'.format(i+1,self.feature_list[i].shape[1], self.feature_list[i].shape[0], self.project_type[i])) + print('-'*40) + + # Projection Matrix Precomputation + self.feature_mat = [] + for feature, proj_type in zip(self.feature_list, self.project_type): + if proj_type == 'remove': + self.feature_mat.append(torch.Tensor(feature @ feature.T).to(self.device)) + elif proj_type == 'retain': + self.feature_mat.append(torch.zeros(feature.shape[0], feature.shape[0]).to(self.device)) + + def get_mat(self, t, train_loader): + + x = torch.cat([b['image'] for b in train_loader], dim = 0).to(self.device) + + # hardcoded, choose 125 input from it + indices = torch.randperm(x.size(0)) + selected_indices = indices[:125] + x = x[selected_indices] + + self.network.eval() + self.network(x, t = t, compute_input_matrix = True) + + mat_list = [] # representation (activation) of each layer + for i, module in enumerate(self.layers): + + if isinstance(module, Conv2d_API): + bsz, ksz, s, inc = batch_list[i], ksize[i], conv_output_size[i], module.in_channels + + mat = np.zeros((ksz * ksz * inc, s * s * bsz)) + act = module.input_matrix.detach().cpu().numpy() + + k = 0 + for kk in range(bsz): + for ii in range(s): + for jj in range(s): + mat[:,k]=act[kk, :, ii:ksz+ii, jj:ksz+jj].reshape(-1) + k += 1 + + mat_list.append(mat) + elif isinstance(module, Linear_API): + mat_list.append(module.input_matrix.detach().cpu().numpy().T) + + return mat_list + + def get_parameters(self, config=None, additional=True): + if additional: + return self.network.parameters() + else: + return [param for name, param in self.network.named_parameters() if 'extra_ws' not in name] + \ No newline at end of file diff --git a/core/model/backbone/SiNet.py b/core/model/backbone/SiNet.py new file mode 100644 index 0000000000000000000000000000000000000000..595481ae518e1b7e06506c1e63278529dcc939a3 --- /dev/null +++ b/core/model/backbone/SiNet.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn +import copy + +from .vit_inflora import VisionTransformer, PatchEmbed, Block, resolve_pretrained_cfg, build_model_with_cfg, checkpoint_filter_fn + +class ViT_lora_co(VisionTransformer): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, n_tasks=10, rank=64): + + super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool, + embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, representation_size=representation_size, + drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, weight_init=weight_init, init_values=init_values, + embed_layer=embed_layer, norm_layer=norm_layer, act_layer=act_layer, block_fn=block_fn, n_tasks=n_tasks, rank=rank) + + def forward(self, x, task_id, register_blk=-1, get_feat=False, get_cur_feat=False): + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + x = x + self.pos_embed[:, :x.size(1), :] + x = self.pos_drop(x) + + prompt_loss = torch.zeros((1,), requires_grad=True).to(x.device) + for i, blk in enumerate(self.blocks): + x = blk(x, task_id, register_blk == i, + get_feat=get_feat, get_cur_feat=get_cur_feat) + + x = self.norm(x) + + return x, prompt_loss + + +def _create_vision_transformer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError( + 'features_only not implemented for Vision Transformer models.') + + # NOTE this extra code to support handling of repr size for in21k pretrained models + # pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) + pretrained_cfg = resolve_pretrained_cfg(variant) + default_num_classes = pretrained_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + repr_size = None + + model = build_model_with_cfg( + ViT_lora_co, variant, pretrained, + pretrained_cfg=pretrained_cfg, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in pretrained_cfg['url'], + **kwargs) + return model + + +class SiNet_vit(nn.Module): + + def __init__(self, **args): + ''' + args is a dictionary with the required arguments. + image_encoder is defined in vit_inflora. + class_num is the number of initial class. + ''' + super(SiNet_vit, self).__init__() + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, + num_heads=12, n_tasks=args["total_sessions"], rank=args["rank"]) + self.image_encoder = _create_vision_transformer( + 'vit_base_patch16_224_in21k', pretrained=True, **model_kwargs) + self.class_num = 1 + self.class_num = args["init_cls"] + self.classifier_pool = nn.ModuleList([ + nn.Linear(args["embd_dim"], self.class_num, bias=True) + for i in range(args["total_sessions"]) + ]) + self.classifier_pool_backup = nn.ModuleList([ + nn.Linear(args["embd_dim"], self.class_num, bias=True) + for i in range(args["total_sessions"]) + ]) + self.numtask = 0 + + @property + def feature_dim(self): + return self.image_encoder.out_dim + + def extract_vector(self, image, task=None): + if task == None: + image_features, _ = self.image_encoder(image, self.numtask-1) + else: + image_features, _ = self.image_encoder(image, task) + image_features = image_features[:, 0, :] + return image_features + + def forward(self, image, get_feat=False, get_cur_feat=False, fc_only=False): + """ + return the output of fully connected layer. + """ + if fc_only: + fc_outs = [] + for ti in range(self.numtask): + fc_out = self.classifier_pool[ti](image) + fc_outs.append(fc_out) + return torch.cat(fc_outs, dim=1) + + logits = [] + image_features, prompt_loss = self.image_encoder( + image, task_id=self.numtask-1, get_feat=get_feat, get_cur_feat=get_cur_feat) + image_features = image_features[:, 0, :] + image_features = image_features.view(image_features.size(0), -1) + for prompts in [self.classifier_pool[self.numtask-1]]: + logits.append(prompts(image_features)) + + return { + 'logits': torch.cat(logits, dim=1), + 'features': image_features, + 'prompt_loss': prompt_loss + } + + def interface(self, image): + image_features, _ = self.image_encoder(image, task_id=self.numtask-1) + + image_features = image_features[:, 0, :] + image_features = image_features.view(image_features.size(0), -1) + + logits = [] + for prompt in self.classifier_pool[:self.numtask]: + logits.append(prompt(image_features)) + + logits = torch.cat(logits, 1) + return logits + + def update_fc(self, nb_classes): + """ + update the number of tasks. + """ + self.numtask += 1 + + def classifier_backup(self, task_id): + self.classifier_pool_backup[task_id].load_state_dict( + self.classifier_pool[task_id].state_dict()) + + def classifier_recall(self): + self.classifier_pool.load_state_dict(self.old_state_dict) + + def copy(self): + return copy.deepcopy(self) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + self.eval() + + return self diff --git a/core/model/backbone/__init__.py b/core/model/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0036f40dcecafc05ebd4339d543d17952a9cab12 --- /dev/null +++ b/core/model/backbone/__init__.py @@ -0,0 +1,30 @@ +from .resnet import * +from .vit import vit_pt_imnet +from .vit import vit_pt_imnet_in21k_adapter +from .vit import vit_cl_lora +from .vit_dap import vit_pt_imnet_dap + +from .SiNet import SiNet_vit + +from .resnet_cbam import * +from .alexnet import AlexNet_TRGP, AlexNet_API +from .clip import clip + +def get_backbone(config): + """ + Get the backbone according to the config dict. + + Args: + config: The config dict. + + Returns: The backbone module. + """ + + kwargs = dict() + kwargs.update(config['backbone']['kwargs']) + try: + emb_func = eval(config["backbone"]['name'])(**kwargs) + except NameError: + raise ("{} is not implemented".format(config["backbone"]['name'])) + + return emb_func diff --git a/core/model/backbone/alexnet.py b/core/model/backbone/alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..012bf84b506a35f1267345a370c668755f314567 --- /dev/null +++ b/core/model/backbone/alexnet.py @@ -0,0 +1,304 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Conv2d_TRGP(nn.Conv2d): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + dilation=1, + groups=1, + bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias) + + # define the scale V + size = self.weight.shape[1] * self.weight.shape[2] * self.weight.shape[3] + self.identity_matrix = torch.eye(size, device = self.weight.device) + + self.space = [] + self.scale_param = nn.ParameterList() + + def enable_scale(self, space): + self.space = space + self.scale_param = nn.ParameterList([nn.Parameter(self.identity_matrix).to(self.weight.device) for _ in self.space]) + + def disable_scale(self): + + self.space = [] + self.scale_param = nn.ParameterList() + + def forward(self, input, compute_input_matrix = False): + + # this should be only called once for each task + if compute_input_matrix: + self.input_matrix = input + + sz = self.weight.shape[0] + + masked_weight = self.weight + + for scale, space in zip(self.scale_param, self.space): + + cropped_scale = scale[:space.size(1), :space.size(1)] + cropped_identity_matrix = self.identity_matrix[:space.shape[1], :space.shape[1]].to(self.weight.device) + + #masked_weight = masked_weight + (self.weight.view(sz, -1) @ space @ (cropped_scale - cropped_identity_matrix) @ space.T).\ + # view(self.weight.shape) + + masked_weight = masked_weight + (masked_weight.view(sz, -1) @ space @ (cropped_scale - cropped_identity_matrix) @ space.T).\ + view(masked_weight.shape) + + + return F.conv2d(input, masked_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class Linear_TRGP(nn.Linear): + + def __init__(self, in_features, out_features, bias = True): + super().__init__(in_features, out_features, bias = bias) + + # define the scale Q + self.identity_matrix = torch.eye(self.weight.shape[1], device = self.weight.device) + + self.space = [] + self.scale_param = nn.ParameterList() + + def enable_scale(self, space): + self.space = space + self.scale_param = nn.ParameterList([nn.Parameter(self.identity_matrix).to(self.weight.device) for _ in self.space]) + + def disable_scale(self): + + self.space = [] + self.scale_param = nn.ParameterList() + + def forward(self, input, compute_input_matrix = False): + + # this should be only called once for each task + if compute_input_matrix: + self.input_matrix = input # save input_matrix here + + masked_weight = self.weight + for scale, space in zip(self.scale_param, self.space): + + cropped_scale = scale[:space.shape[1], :space.shape[1]] + cropped_identity_matrix = self.identity_matrix[:space.shape[1], :space.shape[1]].to(self.weight.device) + + masked_weight = masked_weight + masked_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T # ? + + return F.linear(input, masked_weight, self.bias) + +class AlexNet_TRGP(nn.Module): + + def __init__(self, dropout_rate_1 = 0.2, dropout_rate_2 = 0.5, **kwargs): + + super().__init__() + + self.conv1 = Conv2d_TRGP(in_channels = 3, out_channels = 64, kernel_size = 4, bias = False) + self.bn1 = nn.BatchNorm2d(64, track_running_stats = False) + + self.conv2 = Conv2d_TRGP(in_channels = 64, out_channels = 128, kernel_size = 3, bias = False) + self.bn2 = nn.BatchNorm2d(128, track_running_stats = False) + + self.conv3 = Conv2d_TRGP(in_channels = 128, out_channels = 256, kernel_size = 2, bias = False) + self.bn3 = nn.BatchNorm2d(256, track_running_stats = False) + + self.fc1 = Linear_TRGP(in_features = 1024, out_features = 2048, bias = False) + self.bn4 = nn.BatchNorm1d(2048, track_running_stats = False) + + self.fc2 = Linear_TRGP(in_features = 2048, out_features = 2048, bias=False) + self.bn5 = nn.BatchNorm1d(2048, track_running_stats = False) + + self.feat_dim = 2048 # final feature's dim + + + # common use + self.relu = nn.ReLU() + self.dropout1 = nn.Dropout(dropout_rate_1) + self.dropout2 = nn.Dropout(dropout_rate_2) + self.maxpool = nn.MaxPool2d(kernel_size = 2) + + def forward(self, x, compute_input_matrix): + + x = self.conv1(x, compute_input_matrix) + x = self.bn1(x) + x = self.relu(x) + x = self.dropout1(x) + x = self.maxpool(x) + + x = self.conv2(x, compute_input_matrix) + x = self.bn2(x) + x = self.relu(x) + x = self.dropout1(x) + x = self.maxpool(x) + + x = self.conv3(x, compute_input_matrix) + x = self.bn3(x) + x = self.relu(x) + x = self.dropout2(x) + x = self.maxpool(x) + + x = x.view(x.size(0), -1) + + x = self.fc1(x, compute_input_matrix) + x = self.bn4(x) + x = self.relu(x) + x = self.dropout2(x) + + x = self.fc2(x, compute_input_matrix) + x = self.bn5(x) + x = self.relu(x) + x = self.dropout2(x) + + return x + +# ----- + +class Conv2d_API(nn.Conv2d): + def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): + super().__init__(in_channels, out_channels, kernel_size, stride, padding, bias=bias, dilation=dilation, groups=groups, padding_mode=padding_mode) + + self.extra_ws = nn.ParameterList([]) + self.expand = [] + + def forward(self, input, t, compute_input_matrix = False): + + input = torch.cat([input] + [(input.permute(0, 2, 3, 1) @ self.extra_ws[i]).permute(0, 3, 1, 2) for i in range(t)], dim=1) + + if compute_input_matrix: + self.input_matrix = input + + return F.conv2d(input, self.weight[:, :input.shape[1]], bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + + def duplicate(self, in_channels, extra_w): + dup = Conv2d_API( + self.in_channels + in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.padding_mode + ) + + dup.extra_ws = self.extra_ws + dup.extra_ws.append(extra_w) + dup.expand = self.expand + [in_channels] + + dup.weight.data[:, :self.in_channels].data.copy_(self.weight.data) + + if self.bias is not None: + dup.bias.data[:, :self.in_channels].data.copy_(self.bias.data) + + return dup + +class Linear_API(nn.Linear): + def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: + super().__init__(in_features, out_features, bias, device, dtype) + + self.extra_ws = nn.ParameterList([]) + self.expand = [] + + def forward(self, input, t, compute_input_matrix=False): + + input = torch.cat([input] + [input @ self.extra_ws[i] for i in range(t)], dim=1) + + if compute_input_matrix: + self.input_matrix = input + + return F.linear(input, self.weight[:,:input.shape[1]], bias=self.bias) + + def duplicate(self, in_features, extra_w): + dup = Linear_API( + self.in_features + in_features, + self.out_features, + self.bias is not None + ) + + dup.extra_ws = self.extra_ws + dup.extra_ws.append(extra_w) + dup.expand = self.expand + [in_features] + + dup.weight.data[:, :self.in_features].data.copy_(self.weight.data) + + if self.bias is not None: + dup.bias.data[:, :self.in_features].data.copy_(self.bias.data) + + return dup + +class AlexNet_API(nn.Module): + + def __init__(self, dropout_rate_1 = 0.2, dropout_rate_2 = 0.5, **kwargs): + + super().__init__() + + self.select1, self.select2, self.select3, self.select4, self.select5 = [], [], [], [], [] + + self.conv1 = Conv2d_API(in_channels = 3, out_channels = 64, kernel_size = 4, bias = False) + self.bn1 = nn.BatchNorm2d(64, track_running_stats = False) + + self.conv2 = Conv2d_API(in_channels = 64, out_channels = 128, kernel_size = 3, bias = False) + self.bn2 = nn.BatchNorm2d(128, track_running_stats = False) + + self.conv3 = Conv2d_API(in_channels = 128, out_channels = 256, kernel_size = 2, bias = False) + self.bn3 = nn.BatchNorm2d(256, track_running_stats = False) + + self.fc1 = Linear_API(in_features = 1024, out_features = 2048, bias = False) + self.bn4 = nn.BatchNorm1d(2048, track_running_stats = False) + + self.fc2 = Linear_API(in_features = 2048, out_features = 2048, bias=False) + self.bn5 = nn.BatchNorm1d(2048, track_running_stats = False) + + self.feat_dim = 2048 # final feature's dim + + # common use + self.relu = nn.ReLU() + self.dropout1 = nn.Dropout(dropout_rate_1) + self.dropout2 = nn.Dropout(dropout_rate_2) + self.maxpool = nn.MaxPool2d(kernel_size = 2) + + def forward(self, x, t = 0, compute_input_matrix = False): + + x = self.conv1(x, t, compute_input_matrix) + x = self.bn1(x) + x = self.relu(x) + x = self.dropout1(x) + x = self.maxpool(x) + + x = self.conv2(x, t, compute_input_matrix) + x = self.bn2(x) + x = self.relu(x) + x = self.dropout1(x) + x = self.maxpool(x) + + x = self.conv3(x, t, compute_input_matrix) + x = self.bn3(x) + x = self.relu(x) + x = self.dropout2(x) + x = self.maxpool(x) + + x = x.view(x.size(0), -1) + + x = self.fc1(x, t, compute_input_matrix) + x = self.bn4(x) + x = self.relu(x) + x = self.dropout2(x) + + x = self.fc2(x, t, compute_input_matrix) + x = self.bn5(x) + x = self.relu(x) + x = self.dropout2(x) + + return x + + def expand(self, sizes, extra_ws): + self.conv1 = self.conv1.duplicate(sizes[0], extra_ws[0]) + self.conv2 = self.conv2.duplicate(sizes[1], extra_ws[1]) + self.conv3 = self.conv3.duplicate(sizes[2], extra_ws[2]) + self.fc1 = self.fc1.duplicate(sizes[3], extra_ws[3]) + self.fc2 = self.fc2.duplicate(sizes[4], extra_ws[4]) diff --git a/core/model/backbone/clip.py b/core/model/backbone/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..bea4ad6bed986364e6d6d3c8bba3c13c999ae171 --- /dev/null +++ b/core/model/backbone/clip.py @@ -0,0 +1,668 @@ +''' +Adapted from https://github.com/openai/CLIP +''' + +import os +import json +import hashlib +import urllib +import warnings +from collections import Counter, OrderedDict +from typing import Union, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.distributions.normal import Normal +from tqdm import tqdm + +from .tokenizer.tokenizer import SimpleTokenizer as _Tokenizer +from .petl.adapter import Adapter +from .transformer import LayerNorm, Transformer, VisualTransformer + +class SparseDispatcher(object): + """Helper for implementing a mixture of experts. + The purpose of this class is to create input minibatches for the + experts and to combine the results of the experts to form a unified + output tensor. + There are two functions: + dispatch - take an input Tensor and create input Tensors for each expert. + combine - take output Tensors from each expert and form a combined output + Tensor. Outputs from different experts for the same batch element are + summed together, weighted by the provided "gates". + The class is initialized with a "gates" Tensor, which specifies which + batch elements go to which experts, and the weights to use when combining + the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. + The inputs and outputs are all two-dimensional [batch, depth]. + Caller is responsible for collapsing additional dimensions prior to + calling this class and reshaping the output to the original shape. + See common_layers.reshape_like(). + Example use: + gates: a float32 `Tensor` with shape `[batch_size, num_experts]` + inputs: a float32 `Tensor` with shape `[batch_size, input_size]` + experts: a list of length `num_experts` containing sub-networks. + dispatcher = SparseDispatcher(num_experts, gates) + expert_inputs = dispatcher.dispatch(inputs) + expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] + outputs = dispatcher.combine(expert_outputs) + The preceding code sets the output for a particular example b to: + output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) + This class takes advantage of sparsity in the gate matrix by including in the + `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. + """ + + def __init__(self, num_experts, gates): + """Create a SparseDispatcher.""" + + self._gates = gates + self._num_experts = num_experts + + sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) + + # drop indices + _, self._expert_index = sorted_experts.split(1, dim=1) + # get according batch index for each expert + self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0] + # calculate num samples that each expert gets + self._part_sizes = (gates > 0).sum(0).tolist() + # expand gates to match with self._batch_index + gates_exp = gates[self._batch_index.flatten()] + self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) + + def dispatch(self, inp): + """Create one input Tensor for each expert. + The `Tensor` for a expert `i` contains the slices of `inp` corresponding + to the batch elements `b` where `gates[b, i] > 0`. + Args: + inp: a `Tensor` of shape "[batch_size, ]` + Returns: + a list of `num_experts` `Tensor`s with shapes + `[expert_batch_size_i, ]`. + """ + + # assigns samples to experts whose gate is nonzero + + inp_exp = inp[self._batch_index].squeeze(1) + return torch.split(inp_exp, self._part_sizes, dim=0) + + def combine(self, expert_out, multiply_by_gates=True): + """Sum together the expert output, weighted by the gates. + The slice corresponding to a particular batch element `b` is computed + as the sum over all experts `i` of the expert output, weighted by the + corresponding gate values. If `multiply_by_gates` is set to False, the + gate values are ignored. + Args: + expert_out: a list of `num_experts` `Tensor`s, each with shape + `[expert_batch_size_i, ]`. + multiply_by_gates: a boolean + Returns: + a `Tensor` with shape `[batch_size, ]`. + """ + # apply exp to expert outputs, so we are not longer in log space + + stitched = torch.cat(expert_out, 0) + if multiply_by_gates: + stitched = stitched.mul(self._nonzero_gates) # 加权 + + zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), device=stitched.device) + # combine samples that have been processed by the same k experts + + combined = zeros.index_add(0, self._batch_index, stitched.float()) + # add eps to all zero values in order to avoid nans when going back to log space + # back to log space + return combined + + def expert_to_gates(self): + """Gate values corresponding to the examples in the per-expert `Tensor`s. + Returns: + a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` + and shapes `[expert_batch_size_i]` + """ + # split nonzero gates for each expert + return torch.split(self._nonzero_gates, self._part_sizes, dim=0) + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + +# ----------------------------- + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + baseline = False, + **kwargs + ): + super().__init__() + + self.baseline = baseline + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + + self.visual = VisualTransformer( + img_size=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + depth=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + text_or_image='image', + **kwargs + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + text_or_image='text', + **kwargs + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + #self.logit_scale = nn.Parameter(torch.tensor(100.0)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + #for block in self.transformer.resblocks: + for block in self.transformer.blocks: + # DEBUG + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + nn.init.normal_(block.attn.qkv.weight, std=attn_std) + nn.init.normal_(block.attn.proj.weight, std=proj_std) + nn.init.normal_(block.mlp.fc1.weight, std=fc_std) + nn.init.normal_(block.mlp.fc2.weight, std=proj_std) + + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, **kwargs): + return self.visual(image.type(self.dtype), **kwargs) + + def encode_text(self, text, **kwargs): + + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, **kwargs) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text, **kwargs): + if image is None: + return self.encode_text(text, **kwargs) + elif text is None: + return self.encode_image(image, **kwargs) + image_features = self.encode_image(image, **kwargs) + text_features = self.encode_text(text, **kwargs) + + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logits_per_image.T + + return image_features, text_features, \ + logits_per_image, logits_per_text + +def build_model(state_dict: dict, **kwargs): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + model = CLIP( + + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, **kwargs + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + # nn.MultiheadAttention is replaced with custom MultiheadAttention, the param name is changed to compatible with Pretrained ViT + key_mapping = { + "attn.in_proj_": "attn.qkv.", + "attn.out_proj.": "attn.proj.", + "mlp.c_fc.": "mlp.fc1.", + "mlp.c_proj.": "mlp.fc2.", + ".resblocks.": ".blocks." + } + + modified_state_dict = {} + for key in state_dict.keys(): + new_key = key + for old_key, mapped_key in key_mapping.items(): + if old_key in new_key: + new_key = new_key.replace(old_key, mapped_key) + + modified_state_dict[new_key] = state_dict[key] + + ''' + original_keys = set(model.state_dict().keys()) + modified_keys = set(modified_state_dict.keys()) + + # Print differences + print("Keys in original state dict but not in modified state dict:") + print('\n'.join(original_keys - modified_keys)) # Original keys that are missing in modified + + print('\n') + print("Keys in modified state dict but not in original state dict:") + print('\n'.join(modified_keys - original_keys)) # Modified keys that are extra in modified + assert 0 + ''' + + + model.load_state_dict(modified_state_dict, strict=False) + for p in model.parameters(): + p.data = p.data.float() + return model.eval() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +} + +def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + try: + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + except urllib.error.URLError as e: + print(f"Network error: {e.reason}, Manually download the file from {url} and place at {root}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, pretrained=True, **kwargs): + """Load a CLIP model + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + + # TODO: pretrained is never being used + + if name in _MODELS: + model_path = _download(_MODELS[name]) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {_MODELS.keys()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + try: + model = build_model(state_dict or model.state_dict(), **kwargs).to(device) + except KeyError: + print('Error') + sd = {k[7:]: v for k,v in state_dict["state_dict"].items()} + model = build_model(sd, **kwargs).to(device) + + if str(device) == "cpu": + model.float() + + return model + + assert 0, 'Part below never test, just set jit to False and call it a day' + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, \ + _transform(model.input_resolution.item(), is_train=True), \ + _transform(model.input_resolution.item(), is_train=False) + +_tokenizer = _Tokenizer() +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: # Truncate + tokens = tokens[:context_length] + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + +def clip(model_name, device, jit = False, pretrained = False, **kwargs): + return load(model_name, device, jit, pretrained, **kwargs) \ No newline at end of file diff --git a/core/model/backbone/petl/__init__.py b/core/model/backbone/petl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/core/model/backbone/petl/adapter.py b/core/model/backbone/petl/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7803115fe7c22b4ab251aed18eac51e224289d51 --- /dev/null +++ b/core/model/backbone/petl/adapter.py @@ -0,0 +1,199 @@ +''' +Adapted from https://github.com/jxhe/unify-parameter-efficient-tuning + +TODO: merge this with vit_adapter, either replace one with another +''' + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from core.model.backbone.alexnet import Linear_TRGP + +class Adapter(nn.Module): + def __init__(self, + d_model=None, + bottleneck=None, + dropout=0.0, + init_option="lora", + adapter_scalar="1.0", + adapter_layernorm_option="in"): + super().__init__() + self.n_embd = d_model if d_model is None else d_model + self.down_size = bottleneck + + #_before + self.adapter_layernorm_option = adapter_layernorm_option + + self.adapter_layer_norm_before = None + if adapter_layernorm_option == "in" or adapter_layernorm_option == "out": + self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd) + + if adapter_scalar == "learnable_scalar": + self.scale = nn.Parameter(torch.ones(1)) + else: + self.scale = float(adapter_scalar) + + self.down_proj = nn.Linear(self.n_embd, 64) + self.non_linear_func = nn.ReLU() + self.up_proj = nn.Linear(self.down_size, self.n_embd) + + self.dropout = dropout + if init_option == "bert": + raise NotImplementedError + elif init_option == "lora": + with torch.no_grad(): + nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) + nn.init.zeros_(self.up_proj.weight) + nn.init.zeros_(self.down_proj.bias) + nn.init.zeros_(self.up_proj.bias) + + def forward(self, x, add_residual=True, residual=None): + + residual = x if residual is None else residual + if self.adapter_layernorm_option == 'in': # none + x = self.adapter_layer_norm_before(x) + + down = self.down_proj(x) + down = self.non_linear_func(down) + down = nn.functional.dropout(down, p=self.dropout, training=self.training) + up = self.up_proj(down) + + up = up * self.scale + + if self.adapter_layernorm_option == 'out': # none + up = self.adapter_layer_norm_before(up) + + if add_residual: + output = up + residual + else: + output = up + + return output + +''' +class MaskedAdapter(nn.Module): + def __init__(self, + d_model=None, + bottleneck=None, + dropout=0.0, + init_option="lora", + adapter_scalar="1.0", + adapter_layernorm_option="in"): + super().__init__() + self.n_embd = d_model if d_model is None else d_model + self.down_size = bottleneck + + #_before + self.adapter_layernorm_option = adapter_layernorm_option + + self.adapter_layer_norm_before = None + if adapter_layernorm_option == "in" or adapter_layernorm_option == "out": + self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd) + + if adapter_scalar == "learnable_scalar": + self.scale = nn.Parameter(torch.ones(1)) + else: + self.scale = float(adapter_scalar) + + self.down_proj = nn.Linear(self.n_embd, 64) + self.scale_proj = nn.Linear(64, 64, bias=False) + self.non_linear_func = nn.ReLU() + self.up_proj = nn.Linear(self.down_size, self.n_embd) + + self.dropout = dropout + if init_option == "bert": + raise NotImplementedError + elif init_option == "lora": + with torch.no_grad(): + nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) + nn.init.zeros_(self.up_proj.weight) + nn.init.zeros_(self.down_proj.bias) + nn.init.zeros_(self.up_proj.bias) + + self.identity_matrix = torch.eye(self.scale_proj.weight.shape[1]) + self.input_matrix = None + self.space = [] + self.scale_param = nn.ParameterList() + + def enable_scale(self, space): + self.space = space + self.scale_param = nn.ParameterList([nn.Parameter(self.identity_matrix).to(self.scale_proj.weight.device) for _ in self.space]) + + def disable_scale(self): + self.space = [] + self.scale_param = nn.ParameterList() + + def forward(self, x, add_residual=True, residual=None, compute_input_matrix=False): + + residual = x if residual is None else residual + if self.adapter_layernorm_option == 'in': # none + x = self.adapter_layer_norm_before(x) + + down = self.down_proj(x) + + if compute_input_matrix: + self.input_matrix = down.clone().detach().cpu() + + scale_proj_weight = self.scale_proj.weight + for scale, space in zip(self.scale_param, self.space): + + cropped_scale = scale[:space.shape[1], :space.shape[1]] + cropped_identity_matrix = self.identity_matrix[:space.shape[1], :space.shape[1]].to(self.scale_proj.weight.device) + + scale_proj_weight = scale_proj_weight + scale_proj_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + + down = F.linear(down, scale_proj_weight) + + down = self.non_linear_func(down) + down = nn.functional.dropout(down, p=self.dropout, training=self.training) + + up = self.up_proj(down) + up = up * self.scale + + if self.adapter_layernorm_option == 'out': # none + up = self.adapter_layer_norm_before(up) + + if add_residual: + output = up + residual + else: + output = up + + return output +''' + +class MaskedAdapter(Adapter): + def __init__(self, + d_model=None, + bottleneck=None, + dropout=0.0, + init_option="lora", + adapter_scalar="1.0", + adapter_layernorm_option="in"): + super().__init__(d_model, bottleneck, dropout, init_option, adapter_scalar, adapter_layernorm_option) + + self.down_proj = Linear_TRGP(self.n_embd, 64) + self.up_proj = Linear_TRGP(self.down_size, self.n_embd) + + def forward(self, x, add_residual=True, residual=None, compute_input_matrix=False): + + residual = x if residual is None else residual + if self.adapter_layernorm_option == 'in': # none + x = self.adapter_layer_norm_before(x) + + down = self.down_proj(x, compute_input_matrix) + down = self.non_linear_func(down) + + up = self.up_proj(down, compute_input_matrix) + up = up * self.scale + + if self.adapter_layernorm_option == 'out': # none + up = self.adapter_layer_norm_before(up) + + if add_residual: + output = up + residual + else: + output = up + + return output \ No newline at end of file diff --git a/core/model/backbone/petl/proj.py b/core/model/backbone/petl/proj.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea58b7f8095515736fd78c4709a19f90156d290 --- /dev/null +++ b/core/model/backbone/petl/proj.py @@ -0,0 +1,92 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Proj(nn.Module): + def __init__(self, + d_model=None, + id=-1): + super().__init__() + + self.eye = nn.Parameter(torch.eye(d_model)) + + self.space = [torch.tensor((1)), torch.tensor((1))] + self.scale_param = nn.ParameterList([nn.Parameter(self.eye) for _ in range(2)]) + self.scaling_mask = [False, False] + self.id = -1 + + def forward(self, x, kv_w, expert_id): + + if expert_id == self.id: + pass + else: + return F.linear(x, kv_w) + + pre_kv_w = None + + for mask, scale, space in zip(self.scaling_mask, self.scale_param, self.space): + + if not mask: + break + + scale_size = space.shape[1] + cropped_scale = scale[:scale_size, :scale_size] + + cropped_scale = cropped_scale @ cropped_scale.T # better, idk why + + cropped_identity_matrix = self.eye[:scale_size, :scale_size].to(x) + + if pre_kv_w is None: + pre_kv_w = kv_w + kv_w @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + else: + pre_kv_w = pre_kv_w + pre_kv_w @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + + if pre_kv_w is None: + return F.linear(x, kv_w) + else: + return F.linear(x, pre_kv_w) + +class Proj2(nn.Module): + def __init__(self, + d_model=None, + id=-1): + super().__init__() + + self.eye = nn.Parameter(torch.eye(d_model)) + + self.space = [torch.tensor((1)), torch.tensor((1))] + self.scale_param = nn.ParameterList([nn.Parameter(self.eye) for _ in range(2)]) + self.scaling_mask = [False, False] + self.id = -1 + + def forward(self, x, kv_w, expert_id): + + if expert_id == self.id: + pass + else: + return F.linear(x, kv_w) + + pre_kv_w = None + + for mask, scale, space in zip(self.scaling_mask, self.scale_param, self.space): + + if not mask: + break + + scale_size = space.shape[1] + cropped_scale = scale[:scale_size, :scale_size] + + cropped_scale = cropped_scale @ cropped_scale.T # better, idk why + + cropped_identity_matrix = self.eye[:scale_size, :scale_size].to(x) + + if pre_kv_w is None: + pre_kv_w = kv_w + kv_w @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + else: + pre_kv_w = pre_kv_w + pre_kv_w @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + + if pre_kv_w is None: + return F.linear(x, kv_w) + else: + return F.linear(x, pre_kv_w) \ No newline at end of file diff --git a/core/model/backbone/petl/vision_transformer_adapter.py b/core/model/backbone/petl/vision_transformer_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..eb80b83d61a5ace6e7039987a51ca003c973ac60 --- /dev/null +++ b/core/model/backbone/petl/vision_transformer_adapter.py @@ -0,0 +1,468 @@ +# -------------------------------------------------------- +# References: +# https://github.com/jxhe/unify-parameter-efficient-tuning +# -------------------------------------------------------- + +import math +import torch +import torch.nn as nn +from timm.models.layers import DropPath +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# MAE: https://github.com/facebookresearch/mae +# -------------------------------------------------------- +import timm +from functools import partial +from collections import OrderedDict +import torch +import torch.nn as nn +from timm.models.vision_transformer import PatchEmbed +from timm.models.registry import register_model + +import logging +import os +from collections import OrderedDict +import torch + + + +class Adapter(nn.Module): + def __init__(self, + config=None, + d_model=None, + bottleneck=None, + dropout=0.0, + init_option="bert", + adapter_scalar="1.0", + adapter_layernorm_option="in"): + super().__init__() + self.n_embd = config.d_model if d_model is None else d_model + self.down_size = config.attn_bn if bottleneck is None else bottleneck + + #_before + self.adapter_layernorm_option = adapter_layernorm_option + + self.adapter_layer_norm_before = None + if adapter_layernorm_option == "in" or adapter_layernorm_option == "out": + self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd) + + if adapter_scalar == "learnable_scalar": + self.scale = nn.Parameter(torch.ones(1)) + else: + self.scale = float(adapter_scalar) + + self.down_proj = nn.Linear(self.n_embd, self.down_size) + self.non_linear_func = nn.ReLU() + self.up_proj = nn.Linear(self.down_size, self.n_embd) + + self.dropout = dropout + if init_option == "bert": + raise NotImplementedError + elif init_option == "lora": + with torch.no_grad(): + nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) + nn.init.zeros_(self.up_proj.weight) + nn.init.zeros_(self.down_proj.bias) + nn.init.zeros_(self.up_proj.bias) + + def forward(self, x, add_residual=True, residual=None): + residual = x if residual is None else residual + if self.adapter_layernorm_option == 'in': + x = self.adapter_layer_norm_before(x) + + down = self.down_proj(x) + down = self.non_linear_func(down) + down = nn.functional.dropout(down, p=self.dropout, training=self.training) + up = self.up_proj(down) + + up = up * self.scale + + if self.adapter_layernorm_option == 'out': + up = self.adapter_layer_norm_before(up) + + if add_residual: + output = up + residual + else: + output = up + + return output + + + + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward(self, x): + B, N, C = x.shape + + q = self.q_proj(x) + k = self._shape(self.k_proj(x), -1, B).view(B * self.num_heads, -1, self.head_dim) + v = self._shape(self.v_proj(x), -1, B).view(B * self.num_heads, -1, self.head_dim) + q = self._shape(q, N, B).view(B * self.num_heads, -1, self.head_dim) + + # attn = (q @ k.transpose(-2, -1)) * self.scale + attn_weights = torch.bmm(q, k.transpose(1, 2)) * self.scale + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_probs = self.attn_drop(attn_weights) + attn_output = torch.bmm(attn_probs, v) + + attn_output = attn_output.view(B, self.num_heads, N, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(B, N, C) + + x = self.proj(attn_output) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, config=None, layer_id=None): + super().__init__() + self.config = config + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.fc1 = nn.Linear(dim, mlp_hidden_dim) + self.fc2 = nn.Linear(mlp_hidden_dim, dim) + self.act = act_layer() + self.mlp_drop = nn.Dropout(drop) + + if config.ffn_adapt: + self.adaptmlp = Adapter(self.config, dropout=0.1, bottleneck=config.ffn_num, + init_option=config.ffn_adapter_init_option, + adapter_scalar=config.ffn_adapter_scalar, + adapter_layernorm_option=config.ffn_adapter_layernorm_option, + ) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + if self.config.ffn_adapt and self.config.ffn_option == 'parallel': + adapt_x = self.adaptmlp(x, add_residual=False) + + residual = x + x = self.mlp_drop(self.act(self.fc1(self.norm2(x)))) + x = self.drop_path(self.mlp_drop(self.fc2(x))) + + if self.config.ffn_adapt: + if self.config.ffn_option == 'sequential': + x = self.adaptmlp(x) + elif self.config.ffn_option == 'parallel': + x = x + adapt_x + else: + raise ValueError(self.config.ffn_adapt) + + x = residual + x + return x + + + + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for global average pooling + """ + def __init__(self, global_pool=False, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, + act_layer=None, weight_init='', tuning_config=None): + super().__init__() + + + #print("I'm using ViT with adapters.") + self.tuning_config = tuning_config + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, + config=tuning_config, layer_id=i, + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + # self.init_weights(weight_init) + + ######### MAE begins ############ + self.global_pool = global_pool + if self.global_pool: + self.fc_norm = norm_layer(embed_dim) + + del self.norm # remove the original norm + + ######## Adapter begins ######### + if tuning_config.vpt_on: + assert tuning_config.vpt_num > 0, tuning_config.vpt_num + # properly registered + self.embeddings = nn.ParameterList( # batch, num_prompt, embed_dim + [nn.Parameter(torch.empty(1, self.tuning_config.vpt_num, embed_dim)) for _ in + range(depth)]) + for eee in self.embeddings: + torch.nn.init.xavier_uniform_(eee.data) + + def init_weights(self, mode=''): + raise NotImplementedError() + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def get_classifier(self): + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for idx, blk in enumerate(self.blocks): + if self.tuning_config.vpt_on: + eee = self.embeddings[idx].expand(B, -1, -1) + x = torch.cat([eee, x], dim=1) + x = blk(x) + if self.tuning_config.vpt_on: + x = x[:, self.tuning_config.vpt_num:, :] + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + def forward(self, x): + x = self.forward_features(x,) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.head(x) + return x + + +# def vit_base_patch16(**kwargs): +# model = VisionTransformer( +# patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, +# norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) +# return model + + +# def vit_large_patch16(**kwargs): +# model = VisionTransformer( +# patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, +# norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) +# return model + + +# def vit_huge_patch14(**kwargs): +# model = VisionTransformer( +# patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, +# norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) +# return model + + +# def _create_vision_transformer(variant, pretrained=False, **kwargs): +# if kwargs.get('features_only', None): +# raise RuntimeError('features_only not implemented for Vision Transformer models.') + +# pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) +# model = build_model_with_cfg( +# VisionTransformer, variant, pretrained, +# pretrained_cfg=pretrained_cfg, +# pretrained_filter_fn=checkpoint_filter_fn, +# pretrained_custom_load='npz' in pretrained_cfg['url'], +# **kwargs) +# return model + + + + +def vit_base_patch16_224_adapter(pretrained=False, **kwargs): + + model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + + # checkpoint_model = torch.load('./pretrained_models/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz') + checkpoint_model=timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0) + state_dict = checkpoint_model.state_dict() + # modify the checkpoint state dict to match the model + # first, split qkv weight into q, k, v + for key in list(state_dict.keys()): + if 'qkv.weight' in key: + qkv_weight = state_dict.pop(key) + q_weight = qkv_weight[:768] + k_weight = qkv_weight[768:768*2] + v_weight = qkv_weight[768*2:] + state_dict[key.replace('qkv.weight', 'q_proj.weight')] = q_weight + state_dict[key.replace('qkv.weight', 'k_proj.weight')] = k_weight + state_dict[key.replace('qkv.weight', 'v_proj.weight')] = v_weight + elif 'qkv.bias' in key: + qkv_bias = state_dict.pop(key) + q_bias = qkv_bias[:768] + k_bias = qkv_bias[768:768*2] + v_bias = qkv_bias[768*2:] + state_dict[key.replace('qkv.bias', 'q_proj.bias')] = q_bias + state_dict[key.replace('qkv.bias', 'k_proj.bias')] = k_bias + state_dict[key.replace('qkv.bias', 'v_proj.bias')] = v_bias + # second, modify the mlp.fc.weight to match fc.weight + for key in list(state_dict.keys()): + if 'mlp.fc' in key: + fc_weight = state_dict.pop(key) + state_dict[key.replace('mlp.', '')] = fc_weight + + msg = model.load_state_dict(state_dict, strict=False) + print(msg) + + # s=model.state_dict() + # # print the keys in s + # for key in s.keys(): + # print(key) + # # print the keys in checkpoint_model + # for key in state_dict.keys(): + # if key in s.keys(): + # print(key, 'yes') + # else: + # print(key, 'NOOOOOOOOOOOOOOOOOOO') + + # freeze all but the adapter + for name, p in model.named_parameters(): + if name in msg.missing_keys: + p.requires_grad = True + else: + p.requires_grad = False + return model + + + +def vit_base_patch16_224_in21k_adapter(pretrained=False, **kwargs): + + model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + + # checkpoint_model = torch.load('./pretrained_models/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz') + checkpoint_model=timm.create_model("vit_base_patch16_224_in21k", pretrained=True, num_classes=0) + state_dict = checkpoint_model.state_dict() + # modify the checkpoint state dict to match the model + # first, split qkv weight into q, k, v + for key in list(state_dict.keys()): + if 'qkv.weight' in key: + qkv_weight = state_dict.pop(key) + q_weight = qkv_weight[:768] + k_weight = qkv_weight[768:768*2] + v_weight = qkv_weight[768*2:] + state_dict[key.replace('qkv.weight', 'q_proj.weight')] = q_weight + state_dict[key.replace('qkv.weight', 'k_proj.weight')] = k_weight + state_dict[key.replace('qkv.weight', 'v_proj.weight')] = v_weight + elif 'qkv.bias' in key: + qkv_bias = state_dict.pop(key) + q_bias = qkv_bias[:768] + k_bias = qkv_bias[768:768*2] + v_bias = qkv_bias[768*2:] + state_dict[key.replace('qkv.bias', 'q_proj.bias')] = q_bias + state_dict[key.replace('qkv.bias', 'k_proj.bias')] = k_bias + state_dict[key.replace('qkv.bias', 'v_proj.bias')] = v_bias + # second, modify the mlp.fc.weight to match fc.weight + for key in list(state_dict.keys()): + if 'mlp.fc' in key: + fc_weight = state_dict.pop(key) + state_dict[key.replace('mlp.', '')] = fc_weight + + msg = model.load_state_dict(state_dict, strict=False) + #print(msg) + + # s=model.state_dict() + # # print the keys in s + # for key in s.keys(): + # print(key) + # # print the keys in checkpoint_model + # for key in state_dict.keys(): + # if key in s.keys(): + # print(key, 'yes') + # else: + # print(key, 'NOOOOOOOOOOOOOOOOOOO') + + # freeze all but the adapter + for name, p in model.named_parameters(): + if name in msg.missing_keys: + p.requires_grad = True + else: + p.requires_grad = False + return model + diff --git a/core/model/backbone/petl/vision_transformer_ssf.py b/core/model/backbone/petl/vision_transformer_ssf.py new file mode 100644 index 0000000000000000000000000000000000000000..edfe320d4162d6139ed8c9fd558d6d8dac31896a --- /dev/null +++ b/core/model/backbone/petl/vision_transformer_ssf.py @@ -0,0 +1,872 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020, Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv, resolve_pretrained_cfg, checkpoint_seq +from timm.models.layers import DropPath, trunc_normal_, lecun_normal_, _assert +from timm.models.layers.helpers import to_2tuple +from timm.models.registry import register_model + + + +# import ipdb + + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (weights from official Google JAX impl) + 'vit_tiny_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_tiny_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_base_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + ), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + + + + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), + + +} + + + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., tuning_mode='ssf'): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) + + + def forward(self, x): + x = self.fc1(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.drop2(x) + + return x + + + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., tuning_mode='ssf'): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) + + + + def forward(self, x): + B, N, C = x.shape + if self.tuning_mode == 'ssf': + qkv = (ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1)).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, tuning_mode='ssf'): + super().__init__() + self.dim = dim + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, tuning_mode=tuning_mode) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop, tuning_mode=tuning_mode) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) + + + + def forward(self, x): + if self.tuning_mode == 'ssf': + x = x + self.drop_path1(self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1)))) + x = x + self.drop_path2(self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))) + else: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ResPostBlock(nn.Module): + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.init_values = init_values + + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.init_weights() + + def init_weights(self): + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x): + x = x + self.drop_path1(self.norm1(self.attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class ParallelBlock(nn.Module): + + def __init__( + self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.num_parallel = num_parallel + self.attns = nn.ModuleList() + self.ffns = nn.ModuleList() + for _ in range(num_parallel): + self.attns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + self.ffns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + + def _forward_jit(self, x): + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) + return x + + @torch.jit.ignore + def _forward(self, x): + x = x + sum(attn(x) for attn in self.attns) + x = x + sum(ffn(x) for ffn in self.ffns) + return x + + def forward(self, x): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return self._forward_jit(x) + else: + return self._forward(x) + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, tuning_mode='ssf'): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.norm_layer = norm_layer + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) + + if norm_layer: + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(embed_dim) + + + + def forward(self, x): + B, C, H, W = x.shape + _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + if self.norm_layer: + x = ssf_ada(self.norm(x), self.ssf_scale_2, self.ssf_shift_2) + else: + x = self.norm(x) + else: + x = self.norm(x) + return x + + + +def init_ssf_scale_shift(dim): + scale = nn.Parameter(torch.ones(dim)) + shift = nn.Parameter(torch.zeros(dim)) + + nn.init.normal_(scale, mean=1, std=.02) + nn.init.normal_(shift, std=.02) + + return scale, shift + + +def ssf_ada(x, scale, shift): + assert scale.shape == shift.shape + if x.shape[-1] == scale.shape[0]: + return x * scale + shift + elif x.shape[1] == scale.shape[0]: + return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) + else: + raise ValueError('the input tensor shape does not match the shape of the scale factor.') + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, + class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, tuning_mode='ssf'): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + + #print('Using Pre-trained ViT with Scale & Shift.') + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 if class_token else 0 + self.grad_checkpointing = False + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.tuning_mode = tuning_mode + tuning_mode_list = [tuning_mode] * depth + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(self.num_features) + + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, tuning_mode=tuning_mode_list[i]) + for i in range(depth)]) + + + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + + def forward_features(self, x): + x = self.patch_embed(x) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + + x = self.norm(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + + return x + + +def init_weights_vit_timm(module: nn.Module, name: str = ''): + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): + """ ViT weight initialization, matching JAX (Flax) impl """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_moco(module: nn.Module, name: str = ''): + """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + if isinstance(module, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_init_weights_vit(mode='jax', head_bias: float = 0.): + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) + model = build_model_with_cfg( + VisionTransformer, variant, pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in pretrained_cfg['url'], + **kwargs) + return model + + + +@register_model +def vit_tiny_patch16_224_ssf(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_384_ssf(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) @ 384x384. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + + + +@register_model +def vit_small_patch16_224_ssf(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_384_ssf(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + + + +@register_model +def vit_base_patch16_224_ssf(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_384_ssf(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + + +@register_model +def vit_large_patch16_224_ssf(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_384_ssf(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + + +@register_model +def vit_tiny_patch16_224_in21k_ssf(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_in21k_ssf(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_in21k_ssf(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224_in21k_ssf(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + + diff --git a/core/model/backbone/petl/vpt.py b/core/model/backbone/petl/vpt.py new file mode 100644 index 0000000000000000000000000000000000000000..f4686a5a60964ee1285bbff7a774dfe9d8994157 --- /dev/null +++ b/core/model/backbone/petl/vpt.py @@ -0,0 +1,145 @@ +import timm +import torch +import torch.nn as nn +from timm.models.vision_transformer import VisionTransformer, PatchEmbed + +def build_promptmodel(modelname='vit_base_patch16_224', Prompt_Token_num=10, VPT_type="Deep"): + + # VPT_type = "Deep" / "Shallow" + edge_size=224 + patch_size=16 + num_classes=1000 if modelname == 'vit_base_patch16_224' else 21843 + basic_model = timm.create_model(modelname, pretrained=True) + model = VPT_ViT(Prompt_Token_num=Prompt_Token_num,VPT_type=VPT_type) + # model.New_CLS_head(num_classes) + + # drop head.weight and head.bias + basicmodeldict=basic_model.state_dict() + basicmodeldict.pop('head.weight') + basicmodeldict.pop('head.bias') + + model.load_state_dict(basicmodeldict, False) + + model.head = torch.nn.Identity() + + model.Freeze() + + return model + + +class VPT_ViT(VisionTransformer): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, Prompt_Token_num=1, + VPT_type="Shallow", basic_state_dict=None): + + # Recreate ViT + super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes, + embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, embed_layer=embed_layer, + norm_layer=norm_layer, act_layer=act_layer) + + #print('Using VPT model') + # load basic state_dict + if basic_state_dict is not None: + self.load_state_dict(basic_state_dict, False) + + self.VPT_type = VPT_type + if VPT_type == "Deep": + self.Prompt_Tokens = nn.Parameter(torch.zeros(depth, Prompt_Token_num, embed_dim)) + else: # "Shallow" + self.Prompt_Tokens = nn.Parameter(torch.zeros(1, Prompt_Token_num, embed_dim)) + + def New_CLS_head(self, new_classes=15): + self.head = nn.Linear(self.embed_dim, new_classes) + + def Freeze(self): + for param in self.parameters(): + param.requires_grad = False + + self.Prompt_Tokens.requires_grad = True + try: + for param in self.head.parameters(): + param.requires_grad = True + except: + pass + + def UnFreeze(self): + for param in self.parameters(): + param.requires_grad = True + + def obtain_prompt(self): + prompt_state_dict = {'head': self.head.state_dict(), + 'Prompt_Tokens': self.Prompt_Tokens} + # print(prompt_state_dict) + return prompt_state_dict + + def load_prompt(self, prompt_state_dict): + try: + self.head.load_state_dict(prompt_state_dict['head'], False) + except: + print('head not match, so skip head') + else: + pass + #print('prompt head match') + + if self.Prompt_Tokens.shape == prompt_state_dict['Prompt_Tokens'].shape: + + # device check + Prompt_Tokens = nn.Parameter(prompt_state_dict['Prompt_Tokens'].cpu()) + Prompt_Tokens.to(torch.device(self.Prompt_Tokens.device)) + + self.Prompt_Tokens = Prompt_Tokens + + else: + print('\n !!! cannot load prompt') + print('shape of model req prompt', self.Prompt_Tokens.shape) + print('shape of model given prompt', prompt_state_dict['Prompt_Tokens'].shape) + print('') + + def forward_features(self, x): + x = self.patch_embed(x) + # print(x.shape,self.pos_embed.shape) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + + # concatenate CLS token + x = torch.cat((cls_token, x), dim=1) + x = self.pos_drop(x + self.pos_embed) + + if self.VPT_type == "Deep": + + Prompt_Token_num = self.Prompt_Tokens.shape[1] + + for i in range(len(self.blocks)): + # concatenate Prompt_Tokens + Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0) + # firstly concatenate + x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1) + num_tokens = x.shape[1] + # lastly remove, a genius trick + x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num] + + else: # self.VPT_type == "Shallow" + Prompt_Token_num = self.Prompt_Tokens.shape[1] + + # concatenate Prompt_Tokens + Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1) + x = torch.cat((x, Prompt_Tokens), dim=1) + num_tokens = x.shape[1] + # Sequntially procees + x = self.blocks(x)[:, :num_tokens - Prompt_Token_num] + + x = self.norm(x) + return x + + def forward(self, x): + + x = self.forward_features(x) + + # use cls token for cls head + # x = self.pre_logits(x[:, 0, :]) + x=x[:, 0, :] + + # x = self.head(x) + return x \ No newline at end of file diff --git a/core/model/backbone/prompt.py b/core/model/backbone/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd7721f74a74c6aaec741f0c22d8e9af26bc7b1 --- /dev/null +++ b/core/model/backbone/prompt.py @@ -0,0 +1,497 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{DBLP:conf/cvpr/SmithKGCKAPFK23, + author = {James Seale Smith and + Leonid Karlinsky and + Vyshnavi Gutta and + Paola Cascante{-}Bonilla and + Donghyun Kim and + Assaf Arbelle and + Rameswar Panda and + Rog{\'{e}}rio Feris and + Zsolt Kira}, + title = {CODA-Prompt: COntinual Decomposed Attention-Based Prompting for Rehearsal-Free + Continual Learning}, + booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, + {CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023}, + pages = {11909--11919}, + publisher = {{IEEE}}, + year = {2023} +} + +https://arxiv.org/abs/2211.13218 + +Adapted from https://github.com/GT-RIPL/CODA-Prompt +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +import torchvision.models as models +from torch.autograd import Variable +import numpy as np +import copy + +# source code from https://github.com/GT-RIPL/CODA-Prompt +class CodaPrompt(nn.Module): + def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768): + super().__init__() + self.task_count = 0 + self.emb_d = emb_d + self.key_d = key_dim + self.n_tasks = n_tasks + self._init_smart(emb_d, prompt_param) + + # e prompt init + for e in self.e_layers: + # for model saving/loading simplicity, we init the full parameters here + # however, please note that we reinit the new components at each task + # in the "spirit of continual learning", as we don't know how many tasks + # we will encounter at the start of the task sequence + # + # in the original paper, we used ortho init at the start - this modification is more + # fair in the spirit of continual learning and has little affect on performance + e_l = self.e_p_length + p = tensor_prompt(self.e_pool_size, e_l, emb_d) + k = tensor_prompt(self.e_pool_size, self.key_d) + a = tensor_prompt(self.e_pool_size, self.key_d) + p = self.gram_schmidt(p) + k = self.gram_schmidt(k) + a = self.gram_schmidt(a) + setattr(self, f'e_p_{e}',p) + setattr(self, f'e_k_{e}',k) + setattr(self, f'e_a_{e}',a) + + def _init_smart(self, emb_d, prompt_param): + + # prompt basic param + self.e_pool_size = int(prompt_param[0]) + self.e_p_length = int(prompt_param[1]) + self.e_layers = [0,1,2,3,4] + + # strenth of ortho penalty + self.ortho_mu = prompt_param[2] + + def process_task_count(self): + self.task_count += 1 + + # in the spirit of continual learning, we will reinit the new components + # for the new task with Gram Schmidt + # + # in the original paper, we used ortho init at the start - this modification is more + # fair in the spirit of continual learning and has little affect on performance + # + # code for this function is modified from: + # https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py + for e in self.e_layers: + K = getattr(self,f'e_k_{e}') + A = getattr(self,f'e_a_{e}') + P = getattr(self,f'e_p_{e}') + k = self.gram_schmidt(K) + a = self.gram_schmidt(A) + p = self.gram_schmidt(P) + setattr(self, f'e_p_{e}',p) + setattr(self, f'e_k_{e}',k) + setattr(self, f'e_a_{e}',a) + + # code for this function is modified from: + # https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py + def gram_schmidt(self, vv): + + def projection(u, v): + denominator = (u * u).sum() + + if denominator < 1e-8: + return None + else: + return (v * u).sum() / denominator * u + + # check if the tensor is 3D and flatten the last two dimensions if necessary + is_3d = len(vv.shape) == 3 + if is_3d: + shape_2d = copy.deepcopy(vv.shape) + vv = vv.view(vv.shape[0],-1) + + # swap rows and columns + vv = vv.T + + # process matrix size + nk = vv.size(1) + uu = torch.zeros_like(vv, device=vv.device) + + # get starting point + pt = int(self.e_pool_size / (self.n_tasks)) + s = int(self.task_count * pt) + f = int((self.task_count + 1) * pt) + if s > 0: + uu[:, 0:s] = vv[:, 0:s].clone() + for k in range(s, f): + redo = True + while redo: + redo = False + vk = torch.randn_like(vv[:,k]).to(vv.device) + uk = 0 + for j in range(0, k): + if not redo: + uj = uu[:, j].clone() + proj = projection(uj, vk) + if proj is None: + redo = True + print('restarting!!!') + else: + uk = uk + proj + if not redo: uu[:, k] = vk - uk + for k in range(s, f): + uk = uu[:, k].clone() + uu[:, k] = uk / (uk.norm()) + + # undo swapping of rows and columns + uu = uu.T + + # return from 2D + if is_3d: + uu = uu.view(shape_2d) + + return torch.nn.Parameter(uu) + + def forward(self, x_querry, l, x_block, train=False, task_id=None): + + # e prompts + e_valid = False + if l in self.e_layers: + e_valid = True + B, C = x_querry.shape + + K = getattr(self,f'e_k_{l}') + A = getattr(self,f'e_a_{l}') + p = getattr(self,f'e_p_{l}') + pt = int(self.e_pool_size / (self.n_tasks)) + s = int(self.task_count * pt) + f = int((self.task_count + 1) * pt) + + # freeze/control past tasks + if train: + if self.task_count > 0: + K = torch.cat((K[:s].detach().clone(),K[s:f]), dim=0) + A = torch.cat((A[:s].detach().clone(),A[s:f]), dim=0) + p = torch.cat((p[:s].detach().clone(),p[s:f]), dim=0) + else: + K = K[s:f] + A = A[s:f] + p = p[s:f] + else: + K = K[0:f] + A = A[0:f] + p = p[0:f] + + # with attention and cosine sim + # (b x 1 x d) * soft([1 x k x d]) = (b x k x d) -> attention = k x d + a_querry = torch.einsum('bd,kd->bkd', x_querry, A) + # # (b x k x d) - [1 x k x d] = (b x k) -> key = k x d + n_K = nn.functional.normalize(K, dim=1) + q = nn.functional.normalize(a_querry, dim=2) + aq_k = torch.einsum('bkd,kd->bk', q, n_K) + # (b x 1 x k x 1) * [1 x plen x k x d] = (b x plen x d) -> prompt = plen x k x d + P_ = torch.einsum('bk,kld->bld', aq_k, p) + + # select prompts + i = int(self.e_p_length/2) + Ek = P_[:,:i,:] + Ev = P_[:,i:,:] + + # ortho penalty + if train and self.ortho_mu > 0: + loss = ortho_penalty(K) * self.ortho_mu + loss += ortho_penalty(A) * self.ortho_mu + loss += ortho_penalty(p.view(p.shape[0], -1)) * self.ortho_mu + else: + loss = 0 + else: + loss = 0 + + # combine prompts for prefix tuning + if e_valid: + p_return = [Ek, Ev] + else: + p_return = None + + # return + return p_return, loss, x_block + +def ortho_penalty(t): + return ((t @t.T - torch.eye(t.shape[0]).cuda())**2).mean() + +# @article{wang2022dualprompt, +# title={DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning}, +# author={Wang, Zifeng and Zhang, Zizhao and Ebrahimi, Sayna and Sun, Ruoxi and Zhang, Han and Lee, Chen-Yu and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and others}, +# journal={European Conference on Computer Vision}, +# year={2022} +# } +class DualPrompt(nn.Module): + def __init__(self, emb_d, n_tasks, prompt_param, key_dim=768): + super().__init__() + self.task_count = 0 + self.emb_d = emb_d + self.key_d = key_dim + self.n_tasks = n_tasks + self._init_smart(emb_d, prompt_param) + + # g prompt init + for g in self.g_layers: + p = tensor_prompt(self.g_p_length, emb_d) + setattr(self, f'g_p_{g}',p) + + + # e prompt init + for e in self.e_layers: + p = tensor_prompt(self.e_pool_size, self.e_p_length, emb_d) + k = tensor_prompt(self.e_pool_size, self.key_d) + setattr(self, f'e_p_{e}',p) + setattr(self, f'e_k_{e}',k) + + def _init_smart(self, emb_d, prompt_param): + self.top_k = 1 + self.task_id_bootstrap = True + + # prompt locations + self.g_layers = [0,1] + self.e_layers = [2,3,4] + + # prompt pool size + self.g_p_length = int(prompt_param[2]) + self.e_p_length = int(prompt_param[1]) + self.e_pool_size = int(prompt_param[0]) + + def process_task_count(self): + self.task_count += 1 + + def forward(self, x_querry, l, x_block, train=False, task_id=None): + # e prompts + e_valid = False + if l in self.e_layers: + e_valid = True + B, C = x_querry.shape + K = getattr(self,f'e_k_{l}') # 0 based indexing here + p = getattr(self,f'e_p_{l}') # 0 based indexing here + # print(p.shape) + # cosine similarity to match keys/querries + n_K = nn.functional.normalize(K, dim=1) + q = nn.functional.normalize(x_querry, dim=1).detach() + cos_sim = torch.einsum('bj,kj->bk', q, n_K) + + if train: + # dual prompt during training uses task id + if self.task_id_bootstrap: + loss = (1.0 - cos_sim[:,task_id]).sum() + P_ = p[task_id].expand(len(x_querry),-1,-1) + else: + top_k = torch.topk(cos_sim, self.top_k, dim=1) + k_idx = top_k.indices + loss = (1.0 - cos_sim[:,k_idx]).sum() + P_ = p[k_idx] + else: + top_k = torch.topk(cos_sim, self.top_k, dim=1) + k_idx = top_k.indices + P_ = p[k_idx] + + # select prompts + if train and self.task_id_bootstrap: + i = int(self.e_p_length/2) + Ek = P_[:,:i,:].reshape((B,-1,self.emb_d)) + Ev = P_[:,i:,:].reshape((B,-1,self.emb_d)) + else: + i = int(self.e_p_length/2) + Ek = P_[:,:,:i,:].reshape((B,-1,self.emb_d)) + Ev = P_[:,:,i:,:].reshape((B,-1,self.emb_d)) + + # g prompts + g_valid = False + if l in self.g_layers: + g_valid = True + j = int(self.g_p_length/2) + p = getattr(self,f'g_p_{l}') # 0 based indexing here + P_ = p.expand(len(x_querry),-1,-1) + Gk = P_[:,:j,:] + Gv = P_[:,j:,:] + + # combine prompts for prefix tuning + if e_valid and g_valid: + Pk = torch.cat((Ek, Gk), dim=1) + Pv = torch.cat((Ev, Gv), dim=1) + p_return = [Pk, Pv] + elif e_valid: + p_return = [Ek, Ev] + elif g_valid: + p_return = [Gk, Gv] + loss = 0 + else: + p_return = None + loss = 0 + + + # return + if train: + return p_return, loss, x_block + else: + return p_return, 0, x_block + +# @inproceedings{wang2022learning, +# title={Learning to prompt for continual learning}, +# author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas}, +# booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, +# pages={139--149}, +# year={2022} +# } +class L2P(nn.Module): + + def __init__(self, length, prompt_init=nn.init.uniform_, prompt_key=False, + pool_size=None, top_k=None, num_layers=1, embed_dim=768): + + super().__init__() + self.length = length + self.prompt_init = prompt_init + self.pool_size = pool_size + self.top_k = top_k + self.num_layers = num_layers + self.embed_dim = embed_dim + + # Initialize prompt parameters + self.prompt = nn.Parameter( + torch.empty((self.num_layers, self.pool_size, self.length, embed_dim)) + ) + self.prompt_key = nn.Parameter( + torch.empty((self.pool_size, embed_dim)) + ) + self.prompt_init(self.prompt) + self.prompt_init(self.prompt_key) + + def forward(self, x_embed, cls_features=None): + + B, N, C = x_embed.shape + assert C == self.embed_dim + + # Normalize key features + prompt_key_norm = F.normalize(self.prompt_key, p=2, dim=-1, eps=1e-12) + x_embed_norm = F.normalize(cls_features, p=2, dim=-1, eps=1e-12) + + sim = x_embed_norm @ prompt_key_norm.T + _, idx = torch.topk(sim, self.top_k, dim=1) + + prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True) + + # Manually pad to pool_size, equivalent as jnp.unique() + prompt_id = F.pad(prompt_id, (0, self.pool_size - len(prompt_id)), "constant", prompt_id[0]) + id_counts = F.pad(id_counts, (0, self.pool_size - len(id_counts)), "constant", 0) + + _, major_idx = torch.topk(id_counts, self.top_k) + + major_prompt_id = prompt_id[major_idx] + idx = major_prompt_id.unsqueeze(0).repeat(B, 1) + + batched_prompt_raw = self.prompt[:, idx] + + batched_prompt = batched_prompt_raw.reshape( + batched_prompt_raw.shape[0], + batched_prompt_raw.shape[1], + -1, + batched_prompt_raw.shape[-1] + ) + + # Calculate pull constraint loss + batched_key_norm = prompt_key_norm[idx] + sim_pull = batched_key_norm * x_embed_norm.unsqueeze(1) + reduce_sim = torch.sum(sim_pull) / B + + return batched_prompt, reduce_sim + +# note - ortho init has not been found to help l2p/dual prompt +def tensor_prompt(a, b, c=None, ortho=False): + if c is None: + p = torch.nn.Parameter(torch.FloatTensor(a,b), requires_grad=True) + else: + p = torch.nn.Parameter(torch.FloatTensor(a,b,c), requires_grad=True) + if ortho: + nn.init.orthogonal_(p) + else: + nn.init.uniform_(p) + return p + + + +# @inproceedings{10.24963/ijcai.2024/456, +# author = {Hong, Chenxing and Jin, Yan and Kang, Zhiqi and Chen, Yizhou and Li, Mengke and Lu, Yang and Wang, Hanzi}, +# title = {Dynamically anchored prompting for task-imbalanced continual learning}, +# booktitle = {Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence}, +# year = {2025}, +# } +class DAP(nn.Module): + def __init__(self, length=5, embed_dim=768, embedding_key='mean', prompt_init='uniform', prompt_pool=False, + prompt_key=False, pool_size=None, top_k=None, batchwise_prompt=False, prompt_key_init='uniform',tasklength=10): + super().__init__() + + self.length = length + self.embed_dim = embed_dim + self.prompt_pool = prompt_pool + self.embedding_key = embedding_key + self.prompt_init = prompt_init + self.prompt_key = prompt_key + self.pool_size = pool_size + self.top_k = top_k + self.batchwise_prompt = batchwise_prompt + self.tasklength = tasklength + if self.prompt_pool: + prompt_pool_shape = (pool_size, length, embed_dim) + generalpromt = (top_k, length, embed_dim) + + if prompt_init == 'zero': + self.prompt = nn.Parameter(torch.zeros(prompt_pool_shape)) + self.taskprompt = nn.ParameterList([nn.Parameter(torch.zeros(top_k, length, embed_dim)) for _ in range(tasklength)]) # this is for taskid + self.generalprompt = nn.Parameter(torch.zeros(generalpromt)) + + elif prompt_init == 'uniform': + self.prompt = nn.Parameter(torch.randn(prompt_pool_shape)) + nn.init.uniform_(self.prompt, -1, 1) + self.taskprompt = nn.ParameterList([nn.Parameter(torch.zeros(top_k, length, embed_dim)) for _ in range(tasklength)]) # this is for taskid + for tp in self.taskprompt: + nn.init.uniform_(tp, -1, 1) + self.generalprompt = nn.Parameter(torch.randn(generalpromt)) + nn.init.uniform_(self.generalprompt, -1, 1) + + if prompt_key: + + key_shape = (pool_size, embed_dim) + if prompt_key_init == 'zero': + self.prompt_key = nn.Parameter(torch.zeros(key_shape)) + elif prompt_key_init == 'uniform': + self.prompt_key = nn.Parameter(torch.randn(key_shape)) + nn.init.uniform_(self.prompt_key, -1, 1) + else: + + prompt_mean = torch.mean(self.prompt, dim=1) + self.prompt_key = prompt_mean + + def l2_normalize(self, x, dim=None, epsilon=1e-12): + """Normalizes a given vector or matrix.""" + square_sum = torch.sum(x ** 2, dim=dim, keepdim=True) + x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device))) + return x * x_inv_norm + + def forward(self, x_embed, prompt_mask=None, cls_features=None,taskid=None): + + out = dict() + + top_k, length, c = self.taskprompt[taskid].shape + batched_task_prompt_raw = self.taskprompt[taskid].reshape(top_k * length, c) + batched_task_prompt = batched_task_prompt_raw.unsqueeze(0).expand(x_embed.shape[0], -1, -1) + + batched_general_prompt_raw = self.generalprompt.reshape(top_k * length, c) + batched_general_prompt = batched_general_prompt_raw.unsqueeze(0).expand(x_embed.shape[0], -1, -1) + + + out['total_prompt_len'] = batched_task_prompt.shape[1] + out['prompted_embedding'] = torch.cat([batched_task_prompt, x_embed], dim=1) + + out['gen_total_prompt_len'] = batched_general_prompt.shape[1] + out['gen_prompted_embedding'] = torch.cat([batched_general_prompt, x_embed], dim=1) + return out \ No newline at end of file diff --git a/core/model/backbone/resnet.py b/core/model/backbone/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5a8675d0271f64a92704459fe787cf3172589b --- /dev/null +++ b/core/model/backbone/resnet.py @@ -0,0 +1,778 @@ +''' +Code Reference: +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +https://github.com/G-U-N/PyCIL/blob/master/convs/resnet.py +''' + +import copy +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn.parameter import Parameter + +__all__ = ['resnet18', 'resnet34', 'resnet50', 'cifar_resnet20', 'cifar_resnet32', 'cifar_resnet32_V2', 'resnet32_V2', 'resnet18_AML', 'CosineLinear', 'SplitCosineLinear'] + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None,args=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + + assert args is not None, "you should pass args to resnet" + if 'cifar' in args["dataset"] or '5-datasets' in args["dataset"]: + self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True)) + elif 'imagenet' in args["dataset"]: + if args["init_cls_num"] == args["inc_cls_num"]: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.out_dim = 512 * block.expansion + # self.fc = nn.Linear(512 * block.expansion, num_classes) # Removed in _forward_impl + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + + self.neck = nn.ModuleList() + self.fc = nn.Linear(512, 20) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + x = self.conv1(x) + + x_1 = self.layer1(x) + x_2 = self.layer2(x_1) + x_3 = self.layer3(x_2) + x_4 = self.layer4(x_3) + + pooled = self.avgpool(x_4) + features = torch.flatten(pooled, 1) + + return { + 'fmaps': [x_1, x_2, x_3, x_4], + 'features': features + } + + def forward(self, x): + return self._forward_impl(x) + + def feature(self, x): + x = self.conv1(x) + + x_1 = self.layer1(x) + x_2 = self.layer2(x_1) + x_3 = self.layer3(x_2) + x_4 = self.layer4(x_3) + + pooled = self.avgpool(x_4) + features = torch.flatten(pooled, 1) + + return features + + @property + def last_conv(self): + if hasattr(self.layer4[-1], 'conv3'): + return self.layer4[-1].conv3 + else: + return self.layer4[-1].conv2 + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = ResNet(block, layers, **kwargs) + if pretrained: + raise NotImplementedError + + if 'cosine_fc' in kwargs['args'].keys() and kwargs['args']['cosine_fc']: + in_features = model.fc.in_features + out_features = model.fc.out_features + model.fc = CosineLinear(in_features, out_features) + return model + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + +class ResNetBasicblock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(ResNetBasicblock, self).__init__() + + self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn_a = nn.BatchNorm2d(planes) + + self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_b = nn.BatchNorm2d(planes) + + self.downsample = downsample + + def forward(self, x): + residual = x + + basicblock = self.conv_a(x) + basicblock = self.bn_a(basicblock) + basicblock = F.relu(basicblock, inplace=True) + + basicblock = self.conv_b(basicblock) + basicblock = self.bn_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(x) + + return F.relu(residual + basicblock, inplace=True) + +''' +Code Reference: +https://github.com/G-U-N/PyCIL/blob/master/convs/resnet.py + +We keep this version ResNet to ensure that we can achieve better accuracy. +''' +class CifarResNet(nn.Module): + """ + ResNet optimized for the Cifar Dataset, as specified in + https://arxiv.org/abs/1512.03385.pdf + """ + + def __init__(self, block, depth, channels=3): + super(CifarResNet, self).__init__() + + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' + layer_blocks = (depth - 2) // 6 + + self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_1 = nn.BatchNorm2d(16) + + self.inplanes = 16 + self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) + self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) + self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) + self.avgpool = nn.AvgPool2d(8) + self.out_dim = 64 * block.expansion + # self.fc = nn.Linear(64*block.expansion, 20) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + # m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + m.bias.data.zero_() + + + self.neck = nn.ModuleList() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + # downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv_1_3x3(x) + x = F.relu(self.bn_1(x), inplace=True) + + x_1 = self.stage_1(x) + x_2 = self.stage_2(x_1) + x_3 = self.stage_3(x_2) + + pooled = self.avgpool(x_3) + features = pooled.view(pooled.size(0), -1) + + return { + 'fmaps': [x_1, x_2, x_3], + 'features': features + } + + def feature(self, x): + x = self.conv_1_3x3(x) + x = F.relu(self.bn_1(x), inplace=True) + + x_1 = self.stage_1(x) + x_2 = self.stage_2(x_1) + x_3 = self.stage_3(x_2) + + pooled = self.avgpool(x_3) + features = pooled.view(pooled.size(0), -1) + + return features + + @property + def last_conv(self): + return self.stage_3[-1].conv_b + +''' +Code Reference: +https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py +''' +class CosineLinear(nn.Module): + def __init__(self, in_features, out_features, sigma=True): + super(CosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter(torch.Tensor(out_features, in_features)) + if sigma: + self.sigma = Parameter(torch.Tensor(1)) + else: + self.register_parameter('sigma', None) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.sigma is not None: + self.sigma.data.fill_(1) #for initializaiton of sigma + + def forward(self, input): + out = F.linear(F.normalize(input, p=2,dim=1), \ + F.normalize(self.weight, p=2, dim=1)) + if self.sigma is not None: + out = self.sigma * out + return out + +class SplitCosineLinear(nn.Module): + #consists of two fc layers and concatenate their outputs + def __init__(self, in_features, out_features1, out_features2, sigma=True): + super(SplitCosineLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features1 + out_features2 + self.fc1 = CosineLinear(in_features, out_features1, False) + self.fc2 = CosineLinear(in_features, out_features2, False) + if sigma: + self.sigma = Parameter(torch.Tensor(1)) + self.sigma.data.fill_(1) + else: + self.register_parameter('sigma', None) + + def forward(self, x): + out1 = self.fc1(x) + out2 = self.fc2(x) + out = torch.cat((out1, out2), dim=1) #concatenate along the channel + if self.sigma is not None: + out = self.sigma * out + return out + + +''' +Code Reference: +https://github.com/hshustc/CVPR19_Incremental_Learning + +The version of ResNet used in the official LUCIR repository, if not used, will lead to a decrease in performance. +''' + +class modified_BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, last=False): + super(modified_BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.last = last + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + if not self.last: #remove ReLU in the last layer + out = self.relu(out) + + return out + +class modified_ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=10): + self.inplanes = 16 + super(modified_ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(16) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(block, 16, layers[0]) + self.layer2 = self._make_layer(block, 32, layers[1], stride=2) + self.layer3 = self._make_layer(block, 64, layers[2], stride=2, last_phase=True) + self.avgpool = nn.AvgPool2d(8, stride=1) + # self.fc = modified_linear.CosineLinear(64 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, last_phase=False): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + if last_phase: + for i in range(1, blocks-1): + layers.append(block(self.inplanes, planes)) + layers.append(block(self.inplanes, planes, last=True)) + else: + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + + return {"features": x} + + def feature(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + + return x + +# Temporary Resnet for BIC, TODO: Merge +class BiasLayer(nn.Module): + + def __init__(self): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, requires_grad=True)) + self.beta = nn.Parameter(torch.zeros(1, requires_grad=True)) + + def forward(self, x): + return self.alpha * x + self.beta + +class BasicBlock2(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super().__init__() + self.bn1 = nn.BatchNorm2d(inplanes) + self.relu = nn.ReLU(inplace=True) + + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.bn1(x) + out = self.relu(out) + out = self.conv1(out) + + out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + + return out + +class ResNet_BIC(nn.Module): + + def __init__(self, depth, block_name='BasicBlock2'): + super().__init__() + # Model type specifies number of layers for CIFAR-10 model + if block_name.lower() == 'basicblock2': + assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' + n = (depth - 2) // 6 + block = BasicBlock2 + elif block_name.lower() == 'bottleneck': + assert 0, 'bottleneck is called, should not happen in method : BIC' + assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' + n = (depth - 2) // 9 + block = Bottleneck + else: + raise ValueError('block_name shoule be Basicblock or Bottleneck') + + self.inplanes = 16 + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False) + self.layer1 = self._make_layer(block, 16, n) + self.layer2 = self._make_layer(block, 32, n, stride=2) + self.layer3 = self._make_layer(block, 64, n, stride=2) + self.bn = nn.BatchNorm2d(64 * block.expansion) + self.relu = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(8) + + self.feat_dim = 256 # final feature's dim + #self.feat_dim = 3136 # ImageNet-R + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + x = self.layer1(x) # 32x32 + x = self.layer2(x) # 16x16 + x = self.layer3(x) # 8x8 + x = self.bn(x) + x = self.relu(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + + return x + +# Temporary Resnet for AML (ERACE, ERAML), TODO: Merge +class BasicBlock_AML(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, + stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + self.activation = nn.ReLU() + + def forward(self, x): + out = self.activation(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out = out + self.shortcut(x) + out = self.activation(out) + return out + +class ResNet_AML(nn.Module): + def __init__(self, block, num_blocks, num_classes, nf=20, input_size=(3, 32, 32)): + super().__init__() + self.in_planes = nf + self.input_size = input_size + + self.conv1 = conv3x3(input_size[0], nf * 1) + self.bn1 = nn.BatchNorm2d(nf * 1) + self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) + + self.activation = nn.ReLU() + + with torch.no_grad(): + dummy = torch.zeros(1, *self.input_size) + out = self.forward(dummy) + self.out_dim = out.view(1, -1).shape[1] + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = self.activation(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + return out.view(out.size(0), -1) + + +def cifar_resnet20(pretrained=False, **kwargs): + n = 3 + model = CifarResNet(ResNetBasicblock, 20) + return model + +def cifar_resnet32(pretrained=False, **kwargs): + # EWC + model = CifarResNet(ResNetBasicblock, 32) + return model + +def cifar_resnet32_V2(pretrained=False, **kwargs): + # BIC + return ResNet_BIC(32) + +def resnet32_V2(pretrained=False, **kwargs): + # ## LUCIR: + n = 5 + model = modified_ResNet(modified_BasicBlock, [n, n, n], num_classes=50) + return model + +def resnet18_AML(pretrained=False, **kwargs): + if 'input_size' not in kwargs.keys(): + kwargs['input_size'] = [3, 32, 32] + return ResNet_AML(BasicBlock_AML, [2, 2, 2, 2], kwargs['num_classes'], input_size = kwargs['input_size']) diff --git a/core/model/backbone/resnet_cbam.py b/core/model/backbone/resnet_cbam.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0a69e4dd184836106e9c4a8c8a8d1c9c6346bf --- /dev/null +++ b/core/model/backbone/resnet_cbam.py @@ -0,0 +1,273 @@ +""" +Code Reference: +https://github.com/G-U-N/PyCIL/blob/master/convs/resnet_cbam.py +""" + +import torch +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo +import torch.nn.functional as F + +__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam', + 'resnet152_cbam'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class ChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=16): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) + self.relu1 = nn.ReLU() + self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) + + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) + max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) + out = avg_out + max_out + return self.sigmoid(out) + + +class SpatialAttention(nn.Module): + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + + assert kernel_size in (3, 7), 'kernel size must be 3 or 7' + padding = 3 if kernel_size == 7 else 1 + + self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + x = torch.cat([avg_out, max_out], dim=1) + x = self.conv1(x) + return self.sigmoid(x) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + + self.ca = ChannelAttention(planes) + self.sa = SpatialAttention() + + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.ca = ChannelAttention(planes * 4) + self.sa = SpatialAttention() + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + out = self.ca(out) * out + out = self.sa(out) * out + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=100, args=None): + self.inplanes = 64 + super(ResNet, self).__init__() + assert args is not None, "you should pass args to resnet" + if 'cifar' in args["dataset"]: + self.conv1 = nn.Sequential(nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), nn.ReLU(inplace=True)) + elif 'imagenet' in args["dataset"]: + if args["init_cls"] == args["increment"]: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.feature = nn.AvgPool2d(4, stride=1) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + self.out_dim = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + dim = x.size()[-1] + pool = nn.AvgPool2d(dim, stride=1) + x = pool(x) + x = x.view(x.size(0), -1) + return x + +def resnet18_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet18']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + + +def resnet34_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet34']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + +def resnet50_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet50']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + +def resnet101_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet101']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model + + +def resnet152_cbam(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + pretrained_state_dict = model_zoo.load_url(model_urls['resnet152']) + now_state_dict = model.state_dict() + now_state_dict.update(pretrained_state_dict) + model.load_state_dict(now_state_dict) + return model \ No newline at end of file diff --git a/core/model/backbone/tokenizer/bpe_simple_vocab_16e6.txt.gz b/core/model/backbone/tokenizer/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/core/model/backbone/tokenizer/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/core/model/backbone/tokenizer/tokenizer.py b/core/model/backbone/tokenizer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7a2e07f8f7be842068a0da8bd0d1296bb0172629 --- /dev/null +++ b/core/model/backbone/tokenizer/tokenizer.py @@ -0,0 +1,140 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text \ No newline at end of file diff --git a/core/model/backbone/transformer.py b/core/model/backbone/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..562cb39c0388d33f72cb48bdeb0bd62ffcaedcec --- /dev/null +++ b/core/model/backbone/transformer.py @@ -0,0 +1,2768 @@ +''' +Code Reference: + +* https://github.com/jadore801120/attention-is-all-you-need-pytorch/ +* https://github.com/GT-RIPL/CODA-Prompt +* https://github.com/openai/CLIP +''' + +import os +import math +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from functools import partial +from collections import Counter +from timm.models.vision_transformer import PatchEmbed +from timm.models.layers import trunc_normal_, DropPath +from scipy.special import softmax + +from .petl.adapter import Adapter, MaskedAdapter +from .petl.proj import Proj +from .prompt import L2P + +# Helper +class SparseDispatcher(object): + """Helper for implementing a mixture of experts. + The purpose of this class is to create input minibatches for the + experts and to combine the results of the experts to form a unified + output tensor. + There are two functions: + dispatch - take an input Tensor and create input Tensors for each expert. + combine - take output Tensors from each expert and form a combined output + Tensor. Outputs from different experts for the same batch element are + summed together, weighted by the provided "gates". + The class is initialized with a "gates" Tensor, which specifies which + batch elements go to which experts, and the weights to use when combining + the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. + The inputs and outputs are all two-dimensional [batch, depth]. + Caller is responsible for collapsing additional dimensions prior to + calling this class and reshaping the output to the original shape. + See common_layers.reshape_like(). + Example use: + gates: a float32 `Tensor` with shape `[batch_size, num_experts]` + inputs: a float32 `Tensor` with shape `[batch_size, input_size]` + experts: a list of length `num_experts` containing sub-networks. + dispatcher = SparseDispatcher(num_experts, gates) + expert_inputs = dispatcher.dispatch(inputs) + expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] + outputs = dispatcher.combine(expert_outputs) + The preceding code sets the output for a particular example b to: + output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) + This class takes advantage of sparsity in the gate matrix by including in the + `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. + """ + + def __init__(self, num_experts, gates): + """Create a SparseDispatcher.""" + + self._gates = gates + self._num_experts = num_experts + + sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) + + # drop indices + _, self._expert_index = sorted_experts.split(1, dim=1) + # get according batch index for each expert + self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0] + # calculate num samples that each expert gets + self._part_sizes = (gates > 0).sum(0).tolist() + # expand gates to match with self._batch_index + gates_exp = gates[self._batch_index.flatten()] + self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) + + def dispatch(self, inp): + """Create one input Tensor for each expert. + The `Tensor` for a expert `i` contains the slices of `inp` corresponding + to the batch elements `b` where `gates[b, i] > 0`. + Args: + inp: a `Tensor` of shape "[batch_size, ]` + Returns: + a list of `num_experts` `Tensor`s with shapes + `[expert_batch_size_i, ]`. + """ + + # assigns samples to experts whose gate is nonzero + + inp_exp = inp[self._batch_index].squeeze(1) + return torch.split(inp_exp, self._part_sizes, dim=0) + + def combine(self, expert_out, multiply_by_gates=True): + """Sum together the expert output, weighted by the gates. + The slice corresponding to a particular batch element `b` is computed + as the sum over all experts `i` of the expert output, weighted by the + corresponding gate values. If `multiply_by_gates` is set to False, the + gate values are ignored. + Args: + expert_out: a list of `num_experts` `Tensor`s, each with shape + `[expert_batch_size_i, ]`. + multiply_by_gates: a boolean + Returns: + a `Tensor` with shape `[batch_size, ]`. + """ + # apply exp to expert outputs, so we are not longer in log space + + stitched = torch.cat(expert_out, 0) + if multiply_by_gates: + stitched = stitched.mul(self._nonzero_gates) # 加权 + + zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), device=stitched.device) + # combine samples that have been processed by the same k experts + + combined = zeros.index_add(0, self._batch_index, stitched.float()) + # add eps to all zero values in order to avoid nans when going back to log space + # back to log space + return combined + + def expert_to_gates(self): + """Gate values corresponding to the examples in the per-expert `Tensor`s. + Returns: + a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` + and shapes `[expert_batch_size_i]` + """ + # split nonzero gates for each expert + return torch.split(self._nonzero_gates, self._part_sizes, dim=0) + +# Sub-module of Attention +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +# Attention +class MultiHeadAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0. else nn.Identity() + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity() + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, attn_mask=None, register_hook=False, prompt=None): + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # [3, B, NH, N, HD] + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + if prompt is not None: + pk, pv = prompt + pk = pk.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + pv = pv.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = torch.cat((pk,k), dim=2) + v = torch.cat((pv,v), dim=2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class MultiHeadAttention_LoRA(MultiHeadAttention): + + ''' + Attention module with lora, apply to k, v + ''' + + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): + super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) + + self.lora_rank = lora_rank + + self.lora_A_k = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) + self.lora_B_k = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) + self.lora_A_v = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) + self.lora_B_v = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) + self.apply_lora = False + + self.cur_matrix = torch.zeros(self.dim ,self.dim) + self.n_cur_matrix = 0 + + def init_param(self): + + nn.init.kaiming_uniform_(self.lora_A_k.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A_v.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B_k.weight) + nn.init.zeros_(self.lora_B_v.weight) + + self.apply_lora = True + + def merge_weight(self): + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight + v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight + self.qkv.weight.data = torch.cat([q_weight, k_weight, v_weight], dim=0) + self.apply_lora = False + + def reset_input_matrix(self): + self.cur_matrix.zero_() + self.n_cur_matrix = 0 + + def forward(self, x, attn_mask=None, register_hook=False, prompt=None, get_input_matrix = False): + + if get_input_matrix: + self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + x.shape[0] * x.shape[1]) + self.n_cur_matrix += x.shape[0]*x.shape[1] + + B, N, C = x.shape + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + if self.apply_lora: + k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight + v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight + + qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class MultiHeadAttention_SDLoRA(MultiHeadAttention): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): + super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) + + self.lora_rank = lora_rank + self.lora_bias = lora_bias + + self.lora_A_q_list = nn.ModuleList([]) + self.lora_B_q_list = nn.ModuleList([]) + self.lora_A_v_list = nn.ModuleList([]) + self.lora_B_v_list = nn.ModuleList([]) + + self.assimilated_mag_lora_q = [] + self.assimilated_mag_lora_v = [] + + def init_param(self): + + self.lora_A_q_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)) + self.lora_B_q_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)) + self.lora_A_v_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)) + self.lora_B_v_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)) + + nn.init.kaiming_uniform_(self.lora_A_q_list[-1].weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A_v_list[-1].weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B_q_list[-1].weight) + nn.init.zeros_(self.lora_B_v_list[-1].weight) + + self.assimilated_mag_lora_q.append( + torch.Tensor([0.0]).to(self.qkv.weight.device) + ) + self.assimilated_mag_lora_v.append( + torch.Tensor([0.0]).to(self.qkv.weight.device) + ) + + assert len(self.lora_A_q_list) == len(self.mag_lora) + assert len(self.mag_lora) == len(self.assimilated_mag_lora_q) + + def forward(self, x, attn_mask=None, register_hook=False, prompt=None): + + B, N, C = x.shape + + qq = self.mag_lora[-1] * self.lora_B_q_list[-1](self.lora_A_q_list[-1](x)) + vv = self.mag_lora[-1] * self.lora_B_v_list[-1](self.lora_A_v_list[-1](x)) + + for i in range(len(self.lora_A_q_list) - 1): + + norm_B = torch.norm(self.lora_B_q_list[i].weight) + norm_A = torch.norm(self.lora_A_q_list[i].weight) + + if norm_B != 0 and norm_A != 0: # Only in SD-LoRA-KD, where direction of lora being decomposed + qq += (self.mag_lora[i] + self.assimilated_mag_lora_q[i]) * self.lora_B_q_list[i](self.lora_A_q_list[i](x)) / (norm_B * norm_A) + + norm_B = torch.norm(self.lora_B_v_list[i].weight) + norm_A = torch.norm(self.lora_A_v_list[i].weight) + + if norm_B != 0 and norm_A != 0: + vv += (self.mag_lora[i] + self.assimilated_mag_lora_v[i]) * self.lora_B_v_list[i](self.lora_A_v_list[i](x)) / (norm_B * norm_A) + + qkv = self.qkv(x) + qkv[:, :, : self.dim] += qq + qkv[:, :, -self.dim :] += vv + + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class MultiHeadAttention_LoRA_Sub(MultiHeadAttention): + + ''' + Attention module with lora, apply to k, v + ''' + + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): + super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) + + self.lora_rank = lora_rank + + self.lora_A_k = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) + self.lora_B_k = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) + self.lora_A_v = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) + self.lora_B_v = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) + self.apply_lora = False + + self.cur_matrix = torch.zeros(self.dim ,self.dim) + self.n_cur_matrix = 0 + + self.register_buffer("prev_k_weight", torch.zeros(self.dim, self.dim)) + self.register_buffer("prev_v_weight", torch.zeros(self.dim, self.dim)) + + def init_param(self): + + nn.init.kaiming_uniform_(self.lora_A_k.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A_v.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B_k.weight) + nn.init.zeros_(self.lora_B_v.weight) + + self.apply_lora = True + + def save_weight(self): + + self.prev_k_weight += self.lora_B_k.weight @ self.lora_A_k.weight + self.prev_v_weight += self.lora_B_v.weight @ self.lora_A_v.weight + self.apply_lora = False + + def reset_input_matrix(self): + self.cur_matrix.zero_() + self.n_cur_matrix = 0 + + def forward(self, x, attn_mask=None, register_hook=False, prompt=None, get_input_matrix = False): + + B, N, C = x.shape + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + if get_input_matrix: + # Only in before_task getting input matrix + self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + x.shape[0] * x.shape[1]) + self.n_cur_matrix += x.shape[0]*x.shape[1] + + k_weight = k_weight - self.prev_k_weight + v_weight = v_weight - self.prev_v_weight + + elif self.apply_lora: + # Only in training + k_weight = k_weight + self.prev_k_weight + self.lora_B_k.weight @ self.lora_A_k.weight + v_weight = v_weight + self.prev_v_weight + self.lora_B_v.weight @ self.lora_A_v.weight + else: + # Only in testing + k_weight = k_weight + self.prev_k_weight + v_weight = v_weight + self.prev_v_weight + + qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class MultiHeadAttention_CL_LoRA(MultiHeadAttention_LoRA): + + ''' + Attention module with lora, apply to q, v + ''' + + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): + super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + + del self.lora_A_k + del self.lora_B_k + self.lora_A_q = nn.Linear(self.dim, self.lora_rank, bias=lora_bias) + self.lora_B_q = nn.Linear(self.lora_rank, self.dim, bias=lora_bias) + + def init_param(self): + + q1, _ = torch.linalg.qr(torch.rand(self.dim, self.lora_rank)) + q2, _ = torch.linalg.qr(torch.rand(self.dim, self.lora_rank)) + with torch.no_grad(): + self.lora_A_q.weight.copy_(q1.T) + self.lora_A_v.weight.copy_(q2.T) + + scaling_factor = 1. # You can adjust this value if needed + self.lora_A_q.weight.data *= scaling_factor + self.lora_A_v.weight.data *= scaling_factor + + nn.init.zeros_(self.lora_B_q.weight) + nn.init.zeros_(self.lora_B_v.weight) + + def forward( + self, + x, + adapt=None, + prompt=None, + rank_prompt=None, + block_weight=None, + attn_mask=None, + register_hook=False): + + # custom_adapt including the lora and lora scale weight + # since this method has many set of weights during training/inference, keep changing the module weight is quite exhausting + # lets just pass in as argument + + B, N, C = x.shape + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data) + + if adapt is not None: + if block_weight is not None: + block_weight = block_weight + else: + block_weight = torch.ones(3).to(x.device) + qq = block_weight[0] * adapt[0](x) + vv = block_weight[2] * adapt[2](x) + + qkv[:, :, : self.dim] += qq + qkv[:, :, -self.dim :] += vv + + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +# MInfLoRA +class MultiHeadAttention_MaskedLoRA(MultiHeadAttention_LoRA): + + # Attention module with masked (projection) lora, apply to k, v + + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): + super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + + # Trgp implementation + self.identity_matrix = torch.eye(self.qkv.weight.shape[1]) + + self.space = [[0, 0] for _ in range(10)] + self.scale_param = nn.ModuleList([nn.ParameterList([nn.Parameter(self.identity_matrix) for _ in range(2)]) for _ in range(10)]) + self.scaling_mask = [[False, False] for _ in range(10)] + + def enable_scale(self, task_id, space): + if len(space) == 2: + self.space[task_id][0] = space[0] + self.space[task_id][1] = space[1] + self.scaling_mask[task_id][0] = True + self.scaling_mask[task_id][1] = True + elif len(space) == 1: + self.space[task_id][0] = space[0] + self.scaling_mask[task_id][0] = True + + for scale_param_list in self.scale_param: + for scale_param in scale_param_list: + scale_param = scale_param.to(self.qkv.weight.device) + + def forward(self, x, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix = False): + + if get_input_matrix: + self.cur_matrix = (self.cur_matrix*self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + x.shape[0]*x.shape[1]) + self.n_cur_matrix += x.shape[0]*x.shape[1] + + B, N, C = x.shape + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + if self.apply_lora: + k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight + v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight + + for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]): + + if not mask: + break + + scale_size = space.shape[1] + cropped_scale = scale[:scale_size, :scale_size] + + cropped_scale = cropped_scale @ cropped_scale.T # better, idk why + + cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device) + + k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + + qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +# MInfLoRA1 +class MultiHeadAttention_MaskedLoRA1(MultiHeadAttention): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): + super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) + + self.cur_task = -1 + self.lora_rank = lora_rank + + self.cur_matrix = torch.zeros(self.dim ,self.dim) + self.n_cur_matrix = 0 + + self.lora_bias = lora_bias + + self.lora_A_k_list = nn.ModuleList([]) + self.lora_B_k_list = nn.ModuleList([]) + self.lora_A_v_list = nn.ModuleList([]) + self.lora_B_v_list = nn.ModuleList([]) + + self.space_k = [0 for _ in range(10)] + self.space_v = [0 for _ in range(10)] + self.identity_matrix = torch.eye(self.qkv.weight.shape[1]) + self.scale_param = nn.ParameterList([]) + + def init_param(self): + + self.lora_A_k_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)) + self.lora_B_k_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)) + self.lora_A_v_list.append(nn.Linear(self.dim, self.lora_rank, bias=self.lora_bias)) + self.lora_B_v_list.append(nn.Linear(self.lora_rank, self.dim, bias=self.lora_bias)) + self.scale_param.append(nn.Parameter(self.identity_matrix)) + + nn.init.kaiming_uniform_(self.lora_A_k_list[-1].weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A_v_list[-1].weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B_k_list[-1].weight) + nn.init.zeros_(self.lora_B_v_list[-1].weight) + + self.cur_task += 1 + + def reset_input_matrix(self): + self.cur_matrixs = [] + + def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): + + if get_input_matrix: + assert x.shape[0] < 512 + self.cur_matrixs.append(x.detach()) + + if x.shape[0] > 128: + # do some drift check here + activation = torch.bmm(x.permute(0, 2, 1), x).sum(dim=0) / x.shape[0] + + # get the intersect between previous activation and curr activation + activation = self.lora_A_k_list[-1].weight.data.T @ self.lora_A_k_list[-1].weight.data @ activation + + if self.cur_task > 0: + activation = activation - self.feature_mat @ activation + + U, _, _ = torch.linalg.svd(activation, full_matrices = False) + A_new = U[:,:self.lora_rank].T / math.sqrt(3) + A_old = self.lora_A_k_list[-1].weight.data + Bk_old = self.lora_B_k_list[-1].weight.data + Bv_old = self.lora_B_v_list[-1].weight.data + + tmp = A_old @ torch.pinverse(A_new) + Bk_new = Bk_old @ tmp + Bv_new = Bv_old @ tmp + + ''' + # Compute matmul results + Bk_old_A_old = Bk_old @ A_old + Bk_new_A_new = Bk_new @ A_new + Bv_old_A_old = Bv_old @ A_old + Bv_new_A_new = Bv_new @ A_new + + # Compute the Frobenius norm of the difference between old and new matmul results + frobenius_norm_Bk = torch.norm(Bk_old_A_old - Bk_new_A_new, p='fro') + frobenius_norm_Bv = torch.norm(Bv_old_A_old - Bv_new_A_new, p='fro') + + # Printing the results + print(f"Frobenius norm difference between Bk_old @ A_old and Bk_new @ A_new: {frobenius_norm_Bk.item()}") + print(f"Frobenius norm difference between Bv_old @ A_old and Bv_new @ A_new: {frobenius_norm_Bv.item()}") + ''' + + self.lora_A_k_list[-1].weight.data.copy_(A_new) + self.lora_A_v_list[-1].weight.data.copy_(A_new) + self.lora_B_k_list[-1].weight.data.copy_(Bk_new) + self.lora_B_v_list[-1].weight.data.copy_(Bv_new) + + B, N, C = x.shape + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + for ii in range(self.cur_task): + k_weight = k_weight + self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight + v_weight = v_weight + self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight + + k_weight = k_weight + self.lora_B_k_list[-1].weight @ self.lora_A_k_list[-1].weight + v_weight = v_weight + self.lora_B_v_list[-1].weight @ self.lora_A_v_list[-1].weight + + ''' + for ii in range(self.cur_task): + if not isinstance(self.space_k[ii], int): + + space_k = self.space_k[ii] + space_v = self.space_v[ii] + scale_k = self.scale_param[ii] + + # Q Scaling + scalee = scale_k[:space_k.shape[0], :space_k.shape[0]] + + # QQ^T Scaling + scalee = scale_k[:space_k.shape[0], :space_k.shape[0]] @ scale_k[:space_k.shape[0], :space_k.shape[0]].T + + # QQ^T Diagonal Scaling12 + #scalee = torch.diag(torch.diag(scale_k[:space_k.shape[0], :space_k.shape[0]] @ scale_k[:space_k.shape[0], :space_k.shape[0]].T)) + + # Q Diagonal Scaling + #scalee = torch.diag(torch.diag(scale_k[:space_k.shape[0], :space_k.shape[0]])) + + #scalee = scale_k[0, 0] + scalee = self.mag_lora[ii] + + use_scale = False + if use_scale: + + norm_B = torch.norm(self.lora_B_k_list[ii].weight) + norm_A = torch.norm(self.lora_A_k_list[ii].weight) + + k_weight = k_weight - self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight @ space_k.T @ space_k + k_weight = k_weight + scalee * (self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight @ space_k.T @ space_k) / (norm_B * norm_A) + + norm_B = torch.norm(self.lora_B_v_list[ii].weight) + norm_A = torch.norm(self.lora_A_v_list[ii].weight) + + v_weight = v_weight - self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight @ space_v.T @ space_v + v_weight = v_weight + scalee * (self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight @ space_v.T @ space_v) / (norm_B * norm_A) + ''' + + qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x, x, probs + +# MInfLoRA2 +class MultiHeadAttention_MultiMaskedLoRA(MultiHeadAttention_MaskedLoRA): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): + super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + + self.activated_expert = 0 + self.saved_space = [[torch.tensor((1)), torch.tensor((1))] for _ in range(10)] + + self.hit = 0 + self.total = 0 + self.projected_cur_matrix = torch.zeros(self.dim ,self.dim) + self.n_projected_cur_matrix = 0 + + def reset_input_matrix(self): + super().reset_input_matrix() + self.projected_cur_matrix.zero_() + self.n_projected_cur_matrix = 0 + + def enable_scale(self, task_id, space): + + if len(space) == 2: + self.space[task_id][0] = space[0] + self.space[task_id][1] = space[1] + self.scaling_mask[task_id][0] = True + self.scaling_mask[task_id][1] = True + elif len(space) == 1: + self.space[task_id][0] = space[0] + self.scaling_mask[task_id][0] = True + + for scale_param_list in self.scale_param: + for scale_param in scale_param_list: + scale_param = scale_param.to(self.qkv.weight.device) + + def save_space(self, task_id, space): + self.activated_expert = task_id + self.saved_space[task_id][0] = space + + def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): + + B, N, C = x.shape + + if get_input_matrix: + assert expert_id == 0 + self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + B * N) + self.n_cur_matrix += B * N + + # By Sum and Batch + if not self.training and not get_input_matrix: + with torch.no_grad(): + + cur_cur_matrix = torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0) / (B * N) # (C, C) + saved = torch.stack([self.saved_space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device) # (task_num, C, r) + #saved = torch.stack([self.space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device) # (task_num, C, r) + + proj_mat = saved.transpose(1, 2) # (task_num, r, C) + proj_mat = torch.einsum('ijk,kl->ijl', proj_mat, cur_cur_matrix) # (task_num, r, C) @ (C, C) + + proj_norm = np.linalg.norm(proj_mat.cpu(), axis=(1, 2)) # (task_num, ) + + proj_norm = softmax(proj_norm) + probs.append(proj_norm) + selected_expert_id = np.argmax(proj_norm, axis = 0) # (task_num, ) + + expert_id = selected_expert_id + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + if self.apply_lora: + k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight + v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight + + qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + # ---- + + for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]): + + if not mask: + break + + scale_size = space.shape[1] + cropped_scale = scale[:scale_size, :scale_size] + + cropped_scale = cropped_scale @ cropped_scale.T # better, idk why + + cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device) + + k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + + qkv = F.linear(x_proj, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C) + x_proj = self.proj(x_proj) + x_proj = self.proj_drop(x_proj) + + return x, x_proj, probs + + def forward1(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): + + B, N, C = x.shape + + if get_input_matrix: + assert expert_id == 0 + self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + B * N) + self.n_cur_matrix += B * N + + # By each + if not self.training and not get_input_matrix: + with torch.no_grad(): + + cur_cur_matrix = torch.bmm(x.detach().permute(0, 2, 1), x.detach()) / N # (B, C, C) + cur_cur_matrix = cur_cur_matrix.permute(1, 2, 0) # (C, C, B) + saved = torch.stack([self.saved_space[idd][0] for idd in range(self.activated_expert + 1)]).to(x.device) # (task_num, C, r) + proj_mat = saved.transpose(1, 2) # (task_num, r, C) + + proj_mat = torch.einsum('ijk,klm->ijlm', proj_mat, cur_cur_matrix) # (task_num, r, C) @ (C, C, B) -> (task_num, r, C, B) + + proj_norm = np.linalg.norm(proj_mat, axis=(1, 2)) # (task_num, B) + proj_norm = softmax(proj_norm, axis=0) # (task_num, B) + + probs.append(proj_norm) + + selected_expert_id = np.argmax(proj_norm, axis = 0) # (B, ) + selected_expert_id = torch.tensor(selected_expert_id).to(x.device) + + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + if self.apply_lora: + k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight + v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight + + qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + # ---- + if not self.training and not get_input_matrix: # Test + inputs = [x_proj.clone() for _ in range(self.activated_expert + 1)] + k_weights = [k_weight.clone() for _ in range(self.activated_expert + 1)] + v_weights = [v_weight.clone() for _ in range(self.activated_expert + 1)] + qkv_outputs = [] + + for ex in range(self.activated_expert + 1): + + for mask, scale, space in zip(self.scaling_mask[ex], self.scale_param[ex], self.space[ex]): + + if not mask: + break + + scale_size = space.shape[1] + cropped_scale = scale[:scale_size, :scale_size] + + cropped_scale = cropped_scale @ cropped_scale.T # better, idk why + + cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(x.device) + + k_weights[ex] = k_weights[ex] + k_weights[ex] @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + v_weights[ex] = v_weights[ex] + v_weights[ex] @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + + cur_selected = selected_expert_id.unsqueeze(-1).unsqueeze(-1) + + mask = (cur_selected == ex) + inputs[ex] *= mask + + inputs[ex] = inputs[ex].to(x.device) + q_weight = q_weight.to(x.device) + k_weights[ex] = k_weights[ex].to(x.device) + v_weights[ex] = v_weights[ex].to(x.device) + + qkv = F.linear(inputs[ex], torch.cat([q_weight, k_weights[ex], v_weights[ex]], dim=0)) + qkv_outputs.append(qkv) + + stacked = torch.stack(qkv_outputs) + qkv = torch.sum(stacked, dim=0) + qkv = qkv + self.qkv.bias + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C) + x_proj = self.proj(x_proj) + x_proj = self.proj_drop(x_proj) + + else: + + for mask, scale, space in zip(self.scaling_mask[expert_id], self.scale_param[expert_id], self.space[expert_id]): + + if not mask: + break + + scale_size = space.shape[1] + cropped_scale = scale[:scale_size, :scale_size] + + cropped_scale = cropped_scale @ cropped_scale.T # better, idk why + + cropped_identity_matrix = self.identity_matrix[:scale_size, :scale_size].to(self.qkv.weight.device) + + k_weight = k_weight + k_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + v_weight = v_weight + v_weight @ space @ (cropped_scale - cropped_identity_matrix) @ space.T + + qkv = F.linear(x_proj, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x_proj = (attn @ v).transpose(1, 2).reshape(B, N, C) + x_proj = self.proj(x_proj) + x_proj = self.proj_drop(x_proj) + + return x, x_proj, probs + +# MInfLoRA3 +class MultiHeadAttention_MultiMaskedLoRA3(MultiHeadAttention_MaskedLoRA): + def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., lora_rank=10, lora_bias=False): + super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + + self.cur_task = -1 + + self.lora_A_k_list = nn.ModuleList([nn.Linear(self.dim, self.lora_rank, bias=lora_bias) for _ in range(10)]) + self.lora_B_k_list = nn.ModuleList([nn.Linear(self.lora_rank, self.dim, bias=lora_bias) for _ in range(10)]) + self.lora_A_v_list = nn.ModuleList([nn.Linear(self.dim, self.lora_rank, bias=lora_bias) for _ in range(10)]) + self.lora_B_v_list = nn.ModuleList([nn.Linear(self.lora_rank, self.dim, bias=lora_bias) for _ in range(10)]) + + self.space_k = [0 for _ in range(10)] + self.space_v = [0 for _ in range(10)] + self.scale_param = nn.ParameterList([nn.Parameter(self.identity_matrix) for _ in range(10)]) + + def init_param(self): + + self.cur_task += 1 + + i = self.cur_task + + nn.init.kaiming_uniform_(self.lora_A_k_list[i].weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A_v_list[i].weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B_k_list[i].weight) + nn.init.zeros_(self.lora_B_v_list[i].weight) + + def merge_weight(self): + + print('Not MERGED') + return 0 + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + k_weight = k_weight + self.lora_B_k.weight @ self.lora_A_k.weight + v_weight = v_weight + self.lora_B_v.weight @ self.lora_A_v.weight + + self.apply_lora = False + + for exp_id in range(10): + for ii, mask, scale_k, scale_v, space_k, space_v in zip([0, 1], self.scaling_mask[exp_id], self.scale_param_k[exp_id], self.scale_param_v[exp_id], self.space_k[exp_id], self.space_v[exp_id]): + + if isinstance(space_k, int): + break + + k_weight = k_weight - k_weight @ space_k.T @ space_k + k_weight @ space_k.T @ scale_k[:space_k.shape[0], :space_k.shape[0]] @ space_k + v_weight = v_weight - v_weight @ space_v.T @ space_v + v_weight @ space_v.T @ scale_k[:space_v.shape[0], :space_v.shape[0]] @ space_v + + self.space_k[exp_id][ii] = 0 + + self.qkv.weight.data = torch.cat([q_weight, k_weight, v_weight], dim=0) + + def save_dir(self): + + return 0 + + self.cur_task += 1 + + ''' + + norm = torch.linalg.matrix_norm(self.lora_B_k.weight @ self.lora_A_k.weight) + + self.lora_A_k.weight.data = self.lora_A_k.weight.data / norm + self.lora_B_k.weight.data = self.lora_B_k.weight.data / norm + + self.space_k[self.cur_task][0] = self.lora_A_k.weight.data.clone() / norm + + norm = torch.linalg.matrix_norm(self.lora_B_v.weight @ self.lora_A_v.weight) + + self.lora_A_v.weight.data = self.lora_A_v.weight.data / norm + self.lora_B_v.weight.data = self.lora_B_v.weight.data / norm + + self.space_v[self.cur_task][0] = self.lora_A_v.weight.data.clone() / norm] + ''' + + _, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + U, _, _ = np.linalg.svd(k_weight.data, full_matrices = False) + U, _, _ = np.linalg.svd(U[:, :10], full_matrices = False) + orto_proj = U[:, -50:] + + self.space_k[self.cur_task][0] = torch.Tensor(orto_proj.T).to(self.qkv.weight.device) + + U, _, _ = np.linalg.svd(v_weight.data, full_matrices = False) + U, _, _ = np.linalg.svd(U[:, :10], full_matrices = False) + orto_proj = U[:, -50:] + + self.space_v[self.cur_task][0] = torch.Tensor(orto_proj.T).to(self.qkv.weight.device) + + self.scaling_mask[self.cur_task][0] = True + + def enable_scale(self, task_id, space): + + if len(space) == 2: + self.space[task_id][0] = space[0] + self.space[task_id][1] = space[1] + self.scaling_mask[task_id][0] = True + self.scaling_mask[task_id][1] = True + elif len(space) == 1: + self.space[task_id][0] = space[0] + self.scaling_mask[task_id][0] = True + + for scale_param_list in self.scale_param: + for scale_param in scale_param_list: + scale_param = scale_param.to(self.qkv.weight.device) + + def save_space(self, task_id, space): + self.activated_expert = task_id + self.saved_space[task_id].append(space) + + def forward(self, x, x_proj, probs, attn_mask=None, expert_id=0, register_hook=False, prompt=None, get_input_matrix=False): + + B, N, C = x.shape + + if get_input_matrix: + self.cur_matrix = (self.cur_matrix * self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + B * N) + self.n_cur_matrix += B * N + + q_weight, k_weight, v_weight = self.qkv.weight.chunk(3, dim=0) + + # DEBUG + for exp_id in range(10): + + break + + for mask, scale, space_k, space_v in zip(self.scaling_mask[exp_id], self.scale_param[exp_id], self.space_k[exp_id], self.space_v[exp_id]): + + if isinstance(space_k, int): + break + + cropped_scale = scale[:space_k.shape[0], :space_k.shape[0]] + print( + round(torch.linalg.norm(k_weight @ space_k.T @ space_k, ord='fro').item(), 2), + round(torch.linalg.norm(k_weight @ space_k.T @ cropped_scale @ space_k, ord='fro').item(), 2), + round(torch.linalg.norm(self.lora_B_k.weight @ self.lora_A_k.weight @ space_k.T @ space_k, ord='fro').item(), 2), + round(torch.linalg.norm(self.lora_B_k.weight @ self.lora_A_k.weight @ space_k.T @ cropped_scale @ space_k, ord='fro').item(), 2), + ) + + for ii in range(self.cur_task + 1): + k_weight = k_weight + self.lora_B_k_list[ii].weight @ self.lora_A_k_list[ii].weight + v_weight = v_weight + self.lora_B_v_list[ii].weight @ self.lora_A_v_list[ii].weight + + if not isinstance(self.space_k[ii], int): + + space_k = self.space_k[ii] + space_v = self.space_v[ii] + scale_k = self.scale_param[ii] + + # Q Scaling + scalee = scale_k[:space_k.shape[0], :space_k.shape[0]] #@ scale_k[:space_k.shape[0], :space_k.shape[0]].T + + # QQ^T Scaling + scalee = scale_k[:space_k.shape[0], :space_k.shape[0]] @ scale_k[:space_k.shape[0], :space_k.shape[0]].T + + # QQ^T Diagonal Scaling + scalee = torch.diag(torch.diag(scale_k[:space_k.shape[0], :space_k.shape[0]] @ scale_k[:space_k.shape[0], :space_k.shape[0]].T)) + + # Q Diagonal Scaling + scalee = torch.diag(torch.diag(scale_k[:space_k.shape[0], :space_k.shape[0]])) + + # TODO: Change the Scale to remove scale, and scale by magnitude and direction + # TODO2: following TODO 1, but now unfreeze the previous scale + + use_scale = True + if use_scale: + #print('Enabled scale') + #magnitude = torch.linalg.matrix_norm(space_k, ord='fro') + dir_k = space_k # / magnitude + k_weight = k_weight - k_weight @ space_k.T @ space_k + k_weight @ dir_k.T @ scalee @ dir_k + + #magnitude = torch.linalg.matrix_norm(space_v, ord='fro') + dir_v = space_v # / magnitude + + v_weight = v_weight - v_weight @ space_v.T @ space_v + v_weight @ dir_v.T @ scalee @ dir_v + else: + pass + #print('Disabled scale') + + #if not self.training and not get_input_matrix: + # diagonal_elements = torch.diag(scalee) + # print(ii, diagonal_elements) + + qkv = F.linear(x, torch.cat([q_weight, k_weight, v_weight], dim=0), self.qkv.bias.data).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if attn_mask is not None: + attn += attn_mask.unsqueeze(0) # For head axis broadcasting + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x, x, probs + + +# MLP +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +# Blocks +class ResidualAttentionBlock(nn.Module): + def __init__(self, + d_model: int, + n_head: int, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0., + drop_path: float = 0., + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + norm_layer_eps = 1e-5, + attn_mask: torch.Tensor = None, + text_or_image=None, + # For attn_layer = MultiHeadAttention_LoRA + lora_rank: int = 0, + lora_bias: bool = False + ): + super().__init__() + + if attn_layer == MultiHeadAttention: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop) + elif attn_layer == MultiHeadAttention_LoRA: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + elif attn_layer == MultiHeadAttention_SDLoRA: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + elif attn_layer == MultiHeadAttention_LoRA_Sub: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + elif attn_layer == MultiHeadAttention_MaskedLoRA: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + elif attn_layer == MultiHeadAttention_MultiMaskedLoRA: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + elif attn_layer == MultiHeadAttention_CL_LoRA: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + else: + assert 0, f'{attn_layer} not Implemented' + + self.ln_1 = norm_layer(d_model, eps=norm_layer_eps) + self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer) + self.ln_2 = norm_layer(d_model, eps=norm_layer_eps) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attn_mask = attn_mask + self.text_or_image = text_or_image + + def attention(self, x: torch.Tensor, **kwargs): + self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None + + x = x.permute(1, 0, 2) + attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) + attn = attn.permute(1, 0, 2) + + return attn + + def forward(self, x: torch.Tensor, **kwargs): + + x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) # [Seq, Batch, Dim] + x = x + self.drop_path(self.mlp(self.ln_2(x))) + + return x + +class ResidualAttentionBlock_MLP(ResidualAttentionBlock): + def __init__(self, + d_model: int, + n_head: int, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0., + drop_path: float = 0., + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + attn_mask: torch.Tensor = None, + text_or_image=None, + # For attn_layer = MultiHeadAttention_LoRA + lora_rank: int = 0, + lora_bias: bool = False, + ): + super().__init__( + d_model, + n_head, + mlp_ratio, + qkv_bias, + qk_scale, + attn_drop, + proj_drop, + drop_path, + attn_layer, + act_layer, + norm_layer, + attn_mask, + text_or_image) + + self.ffn_num = 64 + self.adaptmlp = Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, + init_option='lora', adapter_scalar=0.1, adapter_layernorm_option='none') + + self.lora_feature = None # Temporary save the output of adapter, for method : DMNSP + + def attention(self, x: torch.Tensor, **kwargs): + self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None + + x = x.permute(1, 0, 2) + attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) + attn = attn.permute(1, 0, 2) + + return attn + + def forward(self, x: torch.Tensor, compute_lora_feat = False, **kwargs): + + x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) # [Seq, Batch, Dim] + + x_re = x.permute(1, 0, 2) + adapt_x = self.adaptmlp(x_re, add_residual=False) + adapt_x = adapt_x.permute(1, 0, 2) + + x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x) + + if compute_lora_feat: + self.lora_feature = adapt_x.detach().cpu() + + return x + +class ResidualAttentionBlock_MaskedMLP(ResidualAttentionBlock): + def __init__(self, + d_model: int, + n_head: int, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0., + drop_path: float = 0., + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + attn_mask: torch.Tensor = None, + text_or_image=None, + # For attn_layer = MultiHeadAttention_LoRA + lora_rank: int = 0, + lora_bias: bool = False, + ): + super().__init__( + d_model, + n_head, + mlp_ratio, + qkv_bias, + qk_scale, + attn_drop, + proj_drop, + drop_path, + attn_layer, + act_layer, + norm_layer, + attn_mask, + text_or_image) + + self.ffn_num = 64 + self.adaptmlp = MaskedAdapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, + init_option='lora', adapter_scalar=0.1, adapter_layernorm_option='none') + + def attention(self, x: torch.Tensor, **kwargs): + self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None + + x = x.permute(1, 0, 2) + attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) + attn = attn.permute(1, 0, 2) + + return attn + + def forward(self, x: torch.Tensor, compute_input_matrix = False, **kwargs): + + x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) # [Seq, Batch, Dim] + + x_re = x.permute(1, 0, 2) + adapt_x = self.adaptmlp(x_re, add_residual=False, compute_input_matrix=compute_input_matrix) + adapt_x = adapt_x.permute(1, 0, 2) + + x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x) + + return x + +class ResidualAttentionBlock_MoE_MLP(ResidualAttentionBlock): + def __init__(self, + d_model: int, + n_head: int, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0., + drop_path: float = 0., + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + attn_mask: torch.Tensor = None, + text_or_image=None, + # For attn_layer = MultiHeadAttention_LoRA + lora_rank: int = 0, + lora_bias: bool = False, + # MoE + step: int = 0, + experts_num: int = 0, + top_k: int = 0, + noisy_gating: bool = True + ): + super().__init__( + d_model, + n_head, + mlp_ratio, + qkv_bias, + qk_scale, + attn_drop, + proj_drop, + drop_path, + attn_layer, + act_layer, + norm_layer, + attn_mask, + text_or_image) + + assert top_k <= experts_num + + self.register_buffer("mean", torch.tensor([0.0])) + self.register_buffer("std", torch.tensor([1.0])) + self.step = step + self.top_k = top_k + self.noisy_gating = noisy_gating + + self.ffn_num = 64 + self.experts_num = experts_num + self.softmax = nn.Softmax(1) + self.softplus = nn.Softplus() + + self.router_list = nn.ParameterList([ + nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True) for _ in range(self.step) + ]) + self.w_noise_list = nn.ParameterList([ + nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True) for _ in range(self.step) + ]) + + self.adaptmlp_list = nn.ModuleList([ + Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, + init_option='lora', + adapter_scalar=0.1, + adapter_layernorm_option='none') + for _ in range(self.experts_num) + ]) + + self.lora_feature = None # Temporary save the output of adapter, for method : DMNSP + + + def attention(self, x: torch.Tensor, **kwargs): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + + x = x.permute(1, 0, 2) + attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) + attn = attn.permute(1, 0, 2) + + return attn + + def cv_squared(self, x): + """The squared coefficient of variation of a sample. + Useful as a loss to encourage a positive distribution to be more uniform. + Epsilons added for numerical stability. + Returns 0 for an empty Tensor. + Args: + x: a `Tensor`. + Returns: + a `Scalar`. + """ + eps = 1e-10 + # if only num_experts = 1 + + if x.shape[0] == 1: + return torch.tensor([0], device=x.device, dtype=x.dtype) + return x.float().var() / (x.float().mean()**2 + eps) + + def _gates_to_load(self, gates): + """Compute the true load per expert, given the gates. + The load is the number of examples for which the corresponding gate is >0. + Args: + gates: a `Tensor` of shape [batch_size, n] + Returns: + a float32 `Tensor` of shape [n] + """ + return (gates > 0).sum(0) + + def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): + """Helper function to NoisyTopKGating. + Computes the probability that value is in top k, given different random noise. + This gives us a way of backpropagating from a loss that balances the number + of times each expert is in the top k experts per example. + In the case of no noise, pass in None for noise_stddev, and the result will + not be differentiable. + Args: + clean_values: a `Tensor` of shape [batch, n]. + noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus + normally distributed noise with standard deviation noise_stddev. + noise_stddev: a `Tensor` of shape [batch, n], or None + noisy_top_values: a `Tensor` of shape [batch, m]. + "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 + Returns: + a `Tensor` of shape [batch, n]. + """ + # print('1231',clean_values) # 全nan + batch = clean_values.size(0) + m = noisy_top_values.size(1) + top_values_flat = noisy_top_values.flatten() + + threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k + threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) + is_in = torch.gt(noisy_values, threshold_if_in) + threshold_positions_if_out = threshold_positions_if_in - 1 + threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) + # is each value currently in the top k. + normal = Normal(self.mean, self.std) + # + + prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) + prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) + prob = torch.where(is_in, prob_if_in, prob_if_out) + return prob + + def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2): + """Noisy top-k gating. + See paper: https://arxiv.org/abs/1701.06538. + Args: + x: input Tensor with shape [batch_size, input_size] + train: a boolean - we only add noise at training time. + noise_epsilon: a float + Returns: + gates: a Tensor with shape [batch_size, num_experts] + load: a Tensor with shape [num_experts] + """ + + clean_logits = x @ w_gate.to(x) + + if self.noisy_gating and train: + raw_noise_stddev = x @ w_noise.to(x) + noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon)) + noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) + logits = noisy_logits + else: + logits = clean_logits + # calculate topk + 1 that will be needed for the noisy gates + top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1) + top_k_logits = top_logits[:, :self.top_k] + top_k_indices = top_indices[:, :self.top_k] + top_k_gates = self.softmax(top_k_logits) + zeros = torch.zeros_like(logits) + gates = zeros.scatter(1, top_k_indices, top_k_gates) + #if self.noisy_gating and self.top_k < self.experts_num and train: # 目前未用上 + # load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) + #else: + # load = self._gates_to_load(gates) + return gates, None #, load + + def forward(self, x: torch.Tensor, compute_lora_feat=False, **kwargs): + + x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) # [Seq, Batch, Dim] + + x_re = x.permute(1, 0, 2)[:, 0, :] + gates, load = self.noisy_top_k_gating(x_re, self.training, self.router_list[0], + self.w_noise_list[0]) # hardcoded, task_id = 0 + + dispatcher = SparseDispatcher(self.experts_num, gates) + expert_inputs = dispatcher.dispatch(x.permute(1, 0, 2).view(x.shape[1], -1)) + + expert_outputs = [self.adaptmlp_list[i](expert_inputs[i].view(expert_inputs[i].shape[0], + x.shape[0], x.shape[2]).to(x), add_residual=False) + for i in range(self.experts_num)] + + expert_outputs = [out.view(out.shape[0], -1) for out in expert_outputs if out.shape[0] > 0] + + y = dispatcher.combine(expert_outputs) + y = y.view(x.shape[1], x.shape[0], x.shape[2]) + x = x + self.drop_path(self.mlp(self.ln_2(x)) + y.permute(1, 0, 2)) + + return x + +class ResidualAttentionBlock_MoE_Proj(ResidualAttentionBlock): + def __init__(self, + d_model: int, + n_head: int, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0., + drop_path: float = 0., + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + attn_mask: torch.Tensor = None, + text_or_image=None, + # For attn_layer = MultiHeadAttention_LoRA + lora_rank: int = 0, + lora_bias: bool = False, + # MoE + experts_num=0, + ): + super().__init__() + + if isinstance(attn_layer, str): + try: + attn_layer = globals()[attn_layer] + except KeyError: + print(f'{attn_layer} not found, using default MultiHeadAttention') + attn_layer = MultiHeadAttention + + if isinstance(act_layer, str): + try: + act_layer = globals()[act_layer] + except KeyError: + print(f'{act_layer} not found, using default nn.GELU') + act_layer = nn.GELU + + if isinstance(norm_layer, str): + try: + norm_layer = globals()[norm_layer] + except KeyError: + print(f'{norm_layer} not found, using default nn.LayerNorm') + norm_layer = nn.LayerNorm + + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop) + self.ln_1 = norm_layer(d_model) + self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer) + self.ln_2 = norm_layer(d_model) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attn_mask = attn_mask + self.is_train = True + # TODO : make it argument, now harcodrd + if experts_num > 1: + self.register_buffer("mean", torch.tensor([0.0])) + self.register_buffer("std", torch.tensor([1.0])) + self.step = 1 + else: + self.step = 0 + self.top_k = 2 + self.ffn_num = 64 + self.experts_num = experts_num + self.softmax = nn.Softmax(1) + self.softplus = nn.Softplus() + self.noisy_gating = True + self.text_or_image = text_or_image + self.router_list = nn.ParameterList() + self.w_noise_list = nn.ParameterList() + + for i in range(self.step): + self.router_list.append(nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True)) + self.w_noise_list.append(nn.Parameter(torch.zeros(d_model, self.experts_num), requires_grad=True)) + + self.adaptmlp_list = nn.ModuleList() + for i in range(self.experts_num): # + self.adaptmlp_list.append(Adapter(d_model=d_model, dropout=0.1, bottleneck=self.ffn_num, + init_option='lora', + adapter_scalar=0.1, + adapter_layernorm_option='none', + )) + + self.lora_feature = None # Temporary save the output of adapter, for method : DMNSP + + def attention(self, x: torch.Tensor, **kwargs): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + + x = x.permute(1, 0, 2) + attn = self.attn(x, attn_mask=self.attn_mask, **kwargs) + attn = attn.permute(1, 0, 2) + + return attn + + def cv_squared(self, x): + """The squared coefficient of variation of a sample. + Useful as a loss to encourage a positive distribution to be more uniform. + Epsilons added for numerical stability. + Returns 0 for an empty Tensor. + Args: + x: a `Tensor`. + Returns: + a `Scalar`. + """ + eps = 1e-10 + # if only num_experts = 1 + + if x.shape[0] == 1: + return torch.tensor([0], device=x.device, dtype=x.dtype) + return x.float().var() / (x.float().mean()**2 + eps) + + def _gates_to_load(self, gates): + """Compute the true load per expert, given the gates. + The load is the number of examples for which the corresponding gate is >0. + Args: + gates: a `Tensor` of shape [batch_size, n] + Returns: + a float32 `Tensor` of shape [n] + """ + return (gates > 0).sum(0) + + def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): + """Helper function to NoisyTopKGating. + Computes the probability that value is in top k, given different random noise. + This gives us a way of backpropagating from a loss that balances the number + of times each expert is in the top k experts per example. + In the case of no noise, pass in None for noise_stddev, and the result will + not be differentiable. + Args: + clean_values: a `Tensor` of shape [batch, n]. + noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus + normally distributed noise with standard deviation noise_stddev. + noise_stddev: a `Tensor` of shape [batch, n], or None + noisy_top_values: a `Tensor` of shape [batch, m]. + "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 + Returns: + a `Tensor` of shape [batch, n]. + """ + # print('1231',clean_values) # 全nan + batch = clean_values.size(0) + m = noisy_top_values.size(1) + top_values_flat = noisy_top_values.flatten() + + threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.top_k + threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) + is_in = torch.gt(noisy_values, threshold_if_in) + threshold_positions_if_out = threshold_positions_if_in - 1 + threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) + # is each value currently in the top k. + normal = Normal(self.mean, self.std) + # + + prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) + prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) + prob = torch.where(is_in, prob_if_in, prob_if_out) + return prob + + def noisy_top_k_gating(self, x, train, w_gate, w_noise, noise_epsilon=1e-2): + """Noisy top-k gating. + See paper: https://arxiv.org/abs/1701.06538. + Args: + x: input Tensor with shape [batch_size, input_size] + train: a boolean - we only add noise at training time. + noise_epsilon: a float + Returns: + gates: a Tensor with shape [batch_size, num_experts] + load: a Tensor with shape [num_experts] + """ + + clean_logits = x @ w_gate.to(x) + if self.noisy_gating and train: + raw_noise_stddev = x @ w_noise.to(x) + noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon)) + noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) + logits = noisy_logits + else: + logits = clean_logits + # calculate topk + 1 that will be needed for the noisy gates + top_logits, top_indices = logits.topk(min(self.top_k + 1, self.experts_num), dim=1) + top_k_logits = top_logits[:, :self.top_k] + top_k_indices = top_indices[:, :self.top_k] + top_k_gates = self.softmax(top_k_logits) + zeros = torch.zeros_like(logits) + gates = zeros.scatter(1, top_k_indices, top_k_gates) + #if self.noisy_gating and self.top_k < self.experts_num and train: # 目前未用上 + # load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) + #else: + # load = self._gates_to_load(gates) + return gates, None #, load + + def forward(self, x: torch.Tensor, **kwargs): + + x = x + self.drop_path(self.attention(self.ln_1(x), **kwargs)) # [Seq, Batch, Dim] + + if self.experts_num == 0: + + x = x + self.drop_path(self.mlp(self.ln_2(x))) + + elif self.experts_num == 1: + + x_re = x.permute(1, 0, 2) + adapt_x = self.adaptmlp_list[0](x_re, add_residual=False) + adapt_x = adapt_x.permute(1, 0, 2) + + x = x + self.drop_path(self.mlp(self.ln_2(x)) + adapt_x) + + if compute_lora_feat: + self.lora_feature = adapt_x.detach().cpu() + + else: + + x_re = x.permute(1, 0, 2)[:, 0, :] + gates, load = self.noisy_top_k_gating(x_re, self.is_train, self.router_list[0], + self.w_noise_list[0]) # hardcoded, task_id = 0 + + dispatcher = SparseDispatcher(self.experts_num, gates) + expert_inputs = dispatcher.dispatch(x.permute(1, 0, 2).view(x.shape[1], -1)) + + expert_outputs = [self.adaptmlp_list[i](expert_inputs[i].view(expert_inputs[i].shape[0], + x.shape[0], x.shape[2]).to(x), add_residual=False) + for i in range(self.experts_num)] + + expert_outputs = [out.view(out.shape[0], -1) for out in expert_outputs if out.shape[0] > 0] + + y = dispatcher.combine(expert_outputs) + y = y.view(x.shape[1], x.shape[0], x.shape[2]) + x = x + self.drop_path(self.mlp(self.ln_2(x)) + y.permute(1, 0, 2)) + + return x + +class ResidualAttentionBiBlock(nn.Module): + def __init__(self, + d_model: int, + n_head: int, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0., + drop_path: float = 0., + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + attn_mask: torch.Tensor = None, + text_or_image=None, + # For attn_layer = MultiHeadAttention_LoRA + lora_rank: int = 0, + lora_bias: bool = False + ): + super().__init__() + + if attn_layer == MultiHeadAttention: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop) + elif attn_layer == MultiHeadAttention_LoRA: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + elif attn_layer == MultiHeadAttention_MaskedLoRA: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + elif attn_layer == MultiHeadAttention_MultiMaskedLoRA or attn_layer == MultiHeadAttention_MultiMaskedLoRA3 or attn_layer == MultiHeadAttention_MaskedLoRA1: + self.attn = attn_layer(d_model, n_head, qkv_bias, qk_scale, attn_drop, proj_drop, lora_rank, lora_bias) + else: + assert 0, f'{attn_layer} not Implemented' + + self.ln_1 = norm_layer(d_model) + self.mlp = Mlp(d_model, int(d_model * mlp_ratio), act_layer=act_layer) + self.ln_2 = norm_layer(d_model) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attn_mask = attn_mask + self.text_or_image = text_or_image + + def attention(self, x: torch.Tensor, x_proj, probs, **kwargs): + + self.attn_mask = self.attn_mask.to(x) if self.attn_mask is not None else None + + x, x_proj = x.permute(1, 0, 2), x_proj.permute(1, 0, 2) + attn, attn_proj, probs = self.attn(x, x_proj, probs, attn_mask=self.attn_mask, **kwargs) + attn, attn_proj = attn.permute(1, 0, 2), attn_proj.permute(1, 0, 2) + + return attn, attn_proj, probs + + def forward(self, x: torch.Tensor, x_proj, probs, **kwargs): + + attn, attn_proj, probs = self.attention(self.ln_1(x), self.ln_1(x_proj), probs, **kwargs) + + x = x + self.drop_path(attn) # [Seq, Batch, Dim] + x_proj = x_proj + self.drop_path(attn_proj) + + x = x + self.drop_path(self.mlp(self.ln_2(x))) + x_proj = x_proj + self.drop_path(self.mlp(self.ln_2(x_proj))) + + return x, x_proj, probs + +# Transformers +class Transformer(nn.Module): + def __init__(self, + width: int, + layers: int, + heads: int, + block_layer = ResidualAttentionBlock, + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + attn_mask: torch.Tensor = None, + text_or_image=None, + **kwargs + ): + super().__init__() + self.width = width + self.layers = layers + + if isinstance(block_layer, str): + try: + block_layer = globals()[block_layer] + except KeyError: + print(f'{block_layer} not found, using default ResidualAttentionBlock') + block_layer = ResidualAttentionBlock + + if isinstance(attn_layer, str): + try: + attn_layer = globals()[attn_layer] + except KeyError: + print(f'{attn_layer} not found, using default MultiHeadAttention') + attn_layer = MultiHeadAttention + + if isinstance(act_layer, str): + try: + act_layer = globals()[act_layer] + except KeyError: + print(f'{act_layer} not found, using default nn.GELU') + act_layer = nn.GELU + + if isinstance(norm_layer, str): + try: + norm_layer = globals()[norm_layer] + except KeyError: + print(f'{norm_layer} not found, using default nn.LayerNorm') + norm_layer = nn.LayerNorm + + self.blocks = nn.ModuleList([ + block_layer( + d_model=width, + n_head=heads, + attn_layer=attn_layer, + act_layer=act_layer, + norm_layer=norm_layer, + attn_mask=attn_mask, + text_or_image=text_or_image, + **kwargs) + for _ in range(layers)]) + + def forward(self, x: torch.Tensor, l2p_prompt=None, l2p_e_prompt_layer_idx=[], **kwargs): + + prompt_counter = -1 + for i, block in enumerate(self.blocks): + if l2p_prompt is not None and (i in l2p_e_prompt_layer_idx): + prompt_counter += 1 + batched_prompt = l2p_prompt[prompt_counter] + batched_prompt = batched_prompt.permute(1, 0, 2) # (B, Prompt_len, C) -> (Prompt_len, B, C), since x is also (N, B, C) + x = torch.cat([batched_prompt, x], dim=0) # append to dim N + + x = block(x, **kwargs) + + return x + +class Transformer_Proj(Transformer): + def __init__(self, + width: int, + layers: int, + heads: int, + block_layer = ResidualAttentionBlock, + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + attn_mask: torch.Tensor = None, + text_or_image=None, + **kwargs + ): + super().__init__(width, layers, heads, block_layer, attn_layer, act_layer, norm_layer, attn_mask, text_or_image, **kwargs) + self.probs = [] + + def forward(self, x: torch.Tensor, **kwargs): + + x_proj = x.clone() + self.probs = [] + for i, block in enumerate(self.blocks): + x, x_proj, self.probs = block(x, x_proj, self.probs, **kwargs) + + return x_proj + +class Transformer_CL_LoRA(Transformer): + def __init__(self, + width: int, + layers: int, + heads: int, + block_layer = ResidualAttentionBlock, + attn_layer = MultiHeadAttention, + act_layer = nn.GELU, + norm_layer = nn.LayerNorm, + attn_mask: torch.Tensor = None, + text_or_image=None, + **kwargs + ): + super().__init__(width, layers, heads, block_layer, attn_layer, act_layer, norm_layer, attn_mask, text_or_image, **kwargs) + + def forward(self, x, adapt, prompt, rank_prompt, block_weight, **kwargs): + + for idx, blk in enumerate(self.blocks): + + if idx >= 6: + x = blk( + x, + adapt = adapt[idx], + prompt = prompt, + rank_prompt = rank_prompt, + block_weight = block_weight[:, idx - 6], + **kwargs + ) + else: + x = blk( + x, + adapt = adapt[idx], + prompt = prompt, + rank_prompt = rank_prompt, + block_weight = None, + **kwargs + ) + + return x + +# ViT from CLIP +class VisualTransformer(nn.Module): + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int = 3, + width: int = 768, + depth: int = 12, + heads: int = 8, + output_dim: int = 512, + text_or_image: str = None, + **kwargs + ): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.in_chans = in_chans + self.width = width + self.depth = depth + self.heads = heads + self.output_dim = output_dim + + self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((img_size // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, depth, heads, text_or_image=text_or_image, **kwargs) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor, **kwargs): + + x = self.conv1(x) + x = x.reshape(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) + + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND / [Batch_Size, Seq_len, Dim] -> [Seq_len, Batch_Size, Dim] + x = self.transformer(x, **kwargs) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + +# Standard ViT +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + attn_layer=MultiHeadAttention, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + representation_size=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + ckpt_layer=0, + transformer_layer=Transformer, + **kwargs): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_heads = num_heads + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + if transformer_layer == 'Transformer_Proj': + self.transformer = Transformer_Proj(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs) + elif transformer_layer == 'Transformer_CL_LoRA': + self.transformer = Transformer_CL_LoRA(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs) + else: + self.transformer = Transformer(embed_dim, depth, num_heads, text_or_image='image', attn_layer=attn_layer, norm_layer=norm_layer, **kwargs) + self.norm = partial(nn.LayerNorm, eps=1e-6)(embed_dim) + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1, prompt=None, prompt_flag='', q=None, train=False, task_id=-1, cls_features=None, **kwargs): + + B = x.shape[0] + x = self.patch_embed(x) + + if prompt_flag == 'l2p': + + batched_prompt = None + e_prompt_layer_idx = [] + if prompt: + + num_prompted_layers = 1 + e_prompt_layer_idx = [0] + total_prompt_len = prompt.length * prompt.top_k * len(e_prompt_layer_idx) + + batched_prompt, reduce_sim = prompt(x, cls_features=cls_features) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:, :x.size(1), :] + x = self.pos_drop(x) + + x = x.permute(1, 0, 2) # (B, N ,C) -> (N, B ,C) + x = self.transformer( + x, + l2p_prompt = batched_prompt, + l2p_e_prompt_layer_idx = e_prompt_layer_idx, + **kwargs + ) + x = x.permute(1, 0, 2) # (N, B ,C) -> (B, N ,C) + + x = self.norm(x) + + if prompt: + x = x[:, :total_prompt_len] + x = x.mean(dim=1) + return x, reduce_sim + else: + return x[:, 0] + + else: + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:,:x.size(1),:] + x = self.pos_drop(x) + + # TODO: clean, move everything to trasnformer + prompt_loss = torch.zeros((1,), requires_grad=True).to(x.device) + if prompt is not None: + for i,blk in enumerate(self.transformer.blocks): + + if prompt is not None: + if train: + p_list, loss, x = prompt.forward(q, i, x, train=True, task_id=task_id) + prompt_loss += loss + else: + p_list, _, x = prompt.forward(q, i, x, train=False, task_id=task_id) + else: + p_list = None + + # the blk only takes x in shape [N, B, C] not [B, N ,C] + x = x.permute(1, 0, 2) + x = blk(x, register_hook=register_blk==i, prompt=p_list, **kwargs) + x = x.permute(1, 0, 2) + else: + + x = x.permute(1, 0, 2) + x = self.transformer(x, **kwargs) + x = x.permute(1, 0, 2) + + x = self.norm(x) + return x, prompt_loss + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + +class VisionTransformer_CL_LoRA(VisionTransformer): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + + class Adapter_lora(nn.Module): + def __init__(self, + config=None, + d_model=None, + bottleneck=None, + dropout=0.0, + init_option="bert", + adapter_scalar="1.0", + adapter_layernorm_option="in"): + super().__init__() + + self.n_embd = config.d_model if d_model is None else d_model + self.down_size = config.attn_bn if bottleneck is None else bottleneck + + self.lora_A = nn.Linear(self.down_size, self.n_embd, bias=False) + self.lora_B = nn.Linear(self.n_embd, self.down_size, bias=False) + + random_matrix = torch.rand(self.n_embd, self.down_size) + q, r = torch.linalg.qr(random_matrix) + with torch.no_grad(): + self.lora_B.weight.copy_(q.T) + scaling_factor = 1. # You can adjust this value if needed + self.lora_B.weight.data *= scaling_factor + + if init_option == "bert": + raise NotImplementedError + elif init_option == "lora": + with torch.no_grad(): + nn.init.zeros_(self.lora_A.weight) + else: + raise NotImplementedError + + def forward(self, x): + inter_x = self.lora_B(x) + out = self.lora_A(inter_x) + return out + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + attn_layer=MultiHeadAttention, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + representation_size=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=nn.LayerNorm, + ckpt_layer=0, + transformer_layer=Transformer, + **kwargs): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + num_classes=num_classes, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + attn_layer=attn_layer, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + representation_size=representation_size, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_layer=norm_layer, + ckpt_layer=ckpt_layer, + transformer_layer=transformer_layer, + **kwargs + ) + + + cfg_dict = { + 'use_distillation': True, + 'use_block_weight': True, + 'msa_adapt': True, + 'msa': [1, 0, 1], + 'specfic_pos': [6, 7, 8, 9, 10, 11], + 'general_pos': [0, 1, 2, 3, 4, 5], + 'ffn_adapt': True, + 'ffn_option': 'parallel', + 'ffn_adapter_layernorm_option': 'none', + 'ffn_adapter_init_option': 'lora', + 'ffn_adapter_scalar': '0.1', + 'ffn_num': kwargs['lora_rank'], + 'd_model': 768, + 'vpt_on': False, + 'vpt_num': 0, + '_device': 'cuda:0' + } + + from types import SimpleNamespace + + self.tuning_config = SimpleNamespace(**cfg_dict) + self.config = self.tuning_config + + self._device = self.tuning_config._device + self.msa_adapt = self.tuning_config.msa_adapt + self.use_distillation = self.tuning_config.use_distillation + self.use_block_weight = self.tuning_config.use_block_weight + + self.general_pos = self.tuning_config.general_pos + self.specfic_pos = self.tuning_config.specfic_pos + self.adapt_pos = self.general_pos + self.specfic_pos + self.adapt_pos = sorted(self.adapt_pos) + + if self.msa_adapt: + self.msa = self.tuning_config.msa + + if self.use_distillation: + self.old_adapter_list = nn.ModuleList() + + if self.use_block_weight: + self.block_weight_list = [] + self.block_weight = nn.Parameter(torch.randn(3, len(self.specfic_pos))) + nn.init.uniform_(self.block_weight, .5, 1.5) + + self.adapter_list = [] + self.adapter_pos_list = [] + self.cur_adapter = nn.ModuleList() + self.get_new_adapter_initial_msa() + + def forward(self, x, test = False, register_blk=-1, prompt=None, prompt_flag='', q=None, train=False, task_id=-1, cls_features=None, **kwargs): + + if not test: + output = self.forward_train(x) + output = output[:, 0] + return output, None # [bs, 768] + + else: + features = self.forward_test(x) + output = torch.Tensor().to(features[0].device) + for x in features: + cls = x[:, 0, :] + output = torch.cat(( + output, + cls + ), dim=1) + return output, None # [bs, task_id * 768] + + def forward_train(self, x): + + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:,:x.size(1),:] + x = self.pos_drop(x) + + x = x.permute(1, 0, 2) + + x = self.transformer( + x, + adapt = self.cur_adapter, + prompt = None, + rank_prompt = None, + block_weight=self.block_weight) + x = x.permute(1, 0, 2) + x = self.norm(x) + + return x + + def forward_test(self, x, use_init_ptm=False): + import copy + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x_init = self.pos_drop(x) + + features = [] + assert self.config.ffn_adapt + assert self.adapt_pos == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + assert self.general_pos == [0, 1, 2, 3, 4, 5] + assert self.use_block_weight + + # len(self.adapter_list) == cur_task_id + + for i in range(len(self.adapter_list)): + x = copy.deepcopy(x_init) + + x = x.permute(1, 0, 2) + for idx, blk in enumerate(self.transformer.blocks): + + if idx >= 6: + x = blk(x, adapt = self.adapter_list[i][idx - 6], prompt = None, rank_prompt = None, + block_weight=self.block_weight_list[i][:, idx - 6]) + else: + x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, block_weight=None) + x = x.permute(1, 0, 2) + + x = self.norm(x) + features.append(x) + + x = copy.deepcopy(x_init) + x = x.permute(1, 0, 2) + for idx, blk in enumerate(self.transformer.blocks): + + if idx >= 6: + x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, + block_weight=self.block_weight[:, idx - 6]) + else: + x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, block_weight=None) + x = x.permute(1, 0, 2) + + + x = self.norm(x) + features.append(x) + + return features + + def forward_proto(self, x, adapt_index): + assert adapt_index > -1 + assert self.config.ffn_adapt + assert self.use_block_weight + + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + + if adapt_index < len(self.adapter_list): + + x = x.permute(1, 0, 2) + for idx, blk in enumerate(self.transformer.blocks): + + if idx >= 6: + x = blk(x, adapt = self.adapter_list[adapt_index][idx - 6], prompt = None, rank_prompt = None, + block_weight=self.block_weight_list[adapt_index][:, idx - 6]) + else: + x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, block_weight=None) + x = x.permute(1, 0, 2) + + else: + + x = x.permute(1, 0, 2) + for idx, blk in enumerate(self.transformer.blocks): + + if idx >= 6: + x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, + block_weight=self.block_weight[:, idx - 6]) + else: + x = blk(x, adapt = self.cur_adapter[idx], prompt = None, rank_prompt = None, block_weight=None) + x = x.permute(1, 0, 2) + + x = self.norm(x) + x = x[:, 0, :] + + return x + + def forward_general_cls(self, x, t_idx): + import copy + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + x_teacher = copy.deepcopy(x) + + for j in range(6): # [0, ..., 5] + x = self.transformer.blocks[j](x, adapt = self.cur_adapter[j]) + x_teacher = self.transformer.blocks[j](x_teacher, adapt = self.old_adapter_list[t_idx-1][j]) + + x = self.norm(x) + output_new = x[:, 0, :] + + x_teacher = self.norm(x_teacher) + output_teacher= x_teacher[:, 0, :] + + return output_new, output_teacher + + def get_new_adapter_initial_msa(self): + + config = self.config + if config.ffn_adapt: + for i in range(len(self.adapt_pos)): + temp_adapter = nn.ModuleList() + for j in self.msa: + if j ==1: + adapter = VisionTransformer_CL_LoRA.Adapter_lora(self.config, dropout=0.0, bottleneck=config.ffn_num, + init_option=config.ffn_adapter_init_option, + adapter_scalar=config.ffn_adapter_scalar, + adapter_layernorm_option=config.ffn_adapter_layernorm_option, + ).to(self._device) + else: + adapter = nn.Identity() + temp_adapter.append(adapter) + + self.cur_adapter.append(temp_adapter) + self.cur_adapter.requires_grad_(True) + + else: + print("====Not use adapter===") + + def add_adapter_to_list(self): + temp_adapter = [] + import copy + for i in range(len(self.specfic_pos)): + temp_pos = self.adapt_pos.index(self.specfic_pos[i]) + temp_adapter.append(copy.deepcopy(self.cur_adapter[temp_pos].requires_grad_(False))) + self.adapter_list.append(temp_adapter) + + if self.use_block_weight: + self.block_weight_old = copy.deepcopy(self.block_weight) + self.block_weight_list.append(self.block_weight_old.requires_grad_(False)) + self.block_weight = nn.Parameter(torch.randn(3, len(self.specfic_pos))) + nn.init.uniform_(self.block_weight, .5, 1.5) + + self.adapter_pos_list.append(self.adapt_pos) + + if self.use_distillation: + self.old_adapter_list.append(copy.deepcopy(self.cur_adapter).requires_grad_(False)) + if self.msa_adapt: + self.get_new_adapter_msa() + + def get_new_adapter_msa(self): + config = self.config + + if config.ffn_adapt: + for i in range(len(self.specfic_pos)): + pos = self.adapt_pos.index(self.specfic_pos[i]) + temp_adapter = nn.ModuleList() + for j in self.msa: + if j == 1: + adapter = VisionTransformer_CL_LoRA.Adapter_lora(self.config, dropout=0.0, bottleneck=config.ffn_num, + init_option=config.ffn_adapter_init_option, + adapter_scalar=config.ffn_adapter_scalar, + adapter_layernorm_option=config.ffn_adapter_layernorm_option, + ).to(self._device) + adapter.requires_grad_(True) + else: + adapter = nn.Identity() + temp_adapter.append(adapter) + self.cur_adapter[pos] = temp_adapter + + if len(self.specfic_pos) < 12: + self.cur_adapter.requires_grad_(True) + + for i in self.adapt_pos: + if i in self.general_pos: + pos = self.adapt_pos.index(i) + for j in range(len(self.msa)): + if self.msa[j] == 1: + self.cur_adapter[pos][j].lora_B.requires_grad_(False) + else: + print("====Not use adapter===") + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) diff --git a/core/model/backbone/vit.py b/core/model/backbone/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..41eb0fddd08fd220405ad5b7dd96c5bf1c8bb7dd --- /dev/null +++ b/core/model/backbone/vit.py @@ -0,0 +1,305 @@ +''' +Code Reference: +Adapted from https://github.com/GT-RIPL/CODA-Prompt +''' + +import os +import timm +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.vision_transformer import _cfg, PatchEmbed +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath +from timm.models.helpers import named_apply, adapt_input_conv +from .prompt import L2P, CodaPrompt, DualPrompt +from .transformer import MultiHeadAttention_LoRA, VisionTransformer, VisionTransformer_CL_LoRA + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + + if orig_size!=new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) + + return new_pos_embed + else: + return pos_embed_checkpoint + +class ViTZoo(nn.Module): + def __init__(self, pretrained = False, model_name='vit_base_patch16_224', attn_layer='MultiHeadAttention', **kwargs): + super(ViTZoo, self).__init__() + + self.task_id = None + self.feat_dim = 768 + + self.feat = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, + num_heads=12, ckpt_layer=0, + drop_path_rate=0, attn_layer=attn_layer, + **kwargs + ) + + if pretrained: + print(f'Using pretrained model : {model_name}') + + if model_name == 'vit_base_patch16_224.augreg2_in21k_ft_in1k' and os.path.exists('/home/lvqiexuan/.cache/torch/hub/checkpoints/vit_base_patch16_224.augreg2_in21k_ft_in1k.pt'): + # Manually Loading weight + load_dict = torch.load('/home/lvqiexuan/.cache/torch/hub/checkpoints/vit_base_patch16_224.augreg2_in21k_ft_in1k.pt') + else: + load_dict = timm.create_model(model_name, pretrained = pretrained).state_dict() + + key_mapping = { + ".norm1.": ".ln_1.", + ".norm2.": ".ln_2.", + "blocks.": "transformer.blocks." + } + + modified_load_dict = {} + for key in load_dict.keys(): + new_key = key + for old_key, mapped_key in key_mapping.items(): + if old_key in new_key: + new_key = new_key.replace(old_key, mapped_key) + + modified_load_dict[new_key] = load_dict[key] + + self.feat.load_state_dict(modified_load_dict, strict = False) + + self.prompt = None + self.prompt_flag = '' + + def create_prompt(self, prompt_flag, **kwargs): + self.prompt_flag = prompt_flag + + if self.prompt_flag == 'l2p': + self.prompt = L2P(**kwargs) + elif self.prompt_flag == 'dual': + self.prompt = DualPrompt(768, **kwargs) + elif self.prompt_flag == 'coda': + self.prompt = CodaPrompt(768, **kwargs) + + # pen: get penultimate features + def forward(self, image, text=None, pen=False, train=False, **kwargs): + + if self.prompt_flag == 'l2p': + + with torch.no_grad(): + self.eval() + cls_features = self.feat(image, prompt_flag = self.prompt_flag) + + if train: + self.train() + + out, reduce_sim = self.feat( + x = image, + prompt = self.prompt, + cls_features = cls_features, + prompt_flag = self.prompt_flag + ) + + return out, reduce_sim + + elif self.prompt is not None: + with torch.no_grad(): + q, _ = self.feat(image) + q = q[:,0,:] + + # q?, train?, task_id? + out, prompt_loss = self.feat(image, prompt=self.prompt, q=q, train=train, task_id=self.task_id) + out = out[:,0,:] + else: + out, _ = self.feat(image, **kwargs) + if len(out.shape) == 3: + out = out[:,0,:] + + out = out.view(out.size(0), -1) + + if self.prompt is not None and train: + return out, prompt_loss + else: + return out + +class ViT_in21k_adapter(nn.Module): + def __init__(self, pretrained=False, **kwargs): + super(ViT_in21k_adapter, self).__init__() + + self.task_id = None + self.feat_dim = 768 + # get feature encoder + if pretrained: + print("Using pretrained model") + from core.model.backbone.petl import vision_transformer_adapter + from easydict import EasyDict + + tuning_config = EasyDict( + # AdaptFormer + ffn_adapt=True, + ffn_option="parallel", + ffn_adapter_layernorm_option="none", + ffn_adapter_init_option="lora", + ffn_adapter_scalar="0.1", + ffn_num=64, + d_model=768, + # VPT related + vpt_on=False, + vpt_num=0, + ) + + zoo_model = vision_transformer_adapter.vit_base_patch16_224_in21k_adapter(num_classes=0, + global_pool=False, drop_path_rate=0.0, tuning_config=tuning_config) + zoo_model.out_dim=768 + zoo_model.eval() + + self.prompt = None + + # feature encoder changes if transformer vs resnet + self.feat = zoo_model + + def create_prompt(self, prompt_flag, **kwargs): + self.prompt_flag = prompt_flag + # self.prompt_param = prompt_param + # create prompting module + if self.prompt_flag == 'l2p': + self.prompt = L2P(768, **kwargs) + elif self.prompt_flag == 'dual': + self.prompt = DualPrompt(768, **kwargs) + elif self.prompt_flag == 'coda': + self.prompt = CodaPrompt(768, **kwargs) + + # pen: get penultimate features + def forward(self, x, pen=False, train=False): + if self.prompt is not None: + with torch.no_grad(): + q, _ = self.feat(x) + q = q[:,0,:] + out, prompt_loss = self.feat(x, prompt=self.prompt, q=q, train=train, task_id=self.task_id) + out = out[:,0,:] + else: + out = self.feat(x) # This implementation of adapter vit doesn't return prompt loss + + out = out.view(out.size(0), -1) + # if not pen: + # out = self.last(out) + if self.prompt is not None and train: + return out, prompt_loss + else: + return out + +class ViT_CL_LoRA(nn.Module): + def __init__(self, pretrained = False, model_name='vit_base_patch16_224', attn_layer='MultiHeadAttention', **kwargs): + super().__init__() + + self.task_id = None + self.feat_dim = 768 + + self.feat = VisionTransformer_CL_LoRA(img_size=224, patch_size=16, embed_dim=768, depth=12, + num_heads=12, ckpt_layer=0, + drop_path_rate=0, attn_layer=attn_layer, + **kwargs + ) + + if pretrained: + print(f'Using pretrained model : {model_name}') + + if model_name == 'vit_base_patch16_224.augreg2_in21k_ft_in1k' and os.path.exists('/home/lvqiexuan/.cache/torch/hub/checkpoints/vit_base_patch16_224.augreg2_in21k_ft_in1k.pt'): + # Manually Loading weight + load_dict = torch.load('/home/lvqiexuan/.cache/torch/hub/checkpoints/vit_base_patch16_224.augreg2_in21k_ft_in1k.pt') + else: + load_dict = timm.create_model(model_name, pretrained = pretrained).state_dict() + + key_mapping = { + ".norm1.": ".ln_1.", + ".norm2.": ".ln_2.", + "blocks.": "transformer.blocks." + } + + modified_load_dict = {} + for key in load_dict.keys(): + new_key = key + for old_key, mapped_key in key_mapping.items(): + if old_key in new_key: + new_key = new_key.replace(old_key, mapped_key) + + modified_load_dict[new_key] = load_dict[key] + + self.feat.load_state_dict(modified_load_dict, strict = False) + + self.prompt = None + self.prompt_flag = '' + + # pen: get penultimate features + def forward(self, image, test, text=None, pen=False, train=False, **kwargs): + + if self.prompt_flag == 'l2p': + + with torch.no_grad(): + self.eval() + cls_features = self.feat(image, prompt_flag = self.prompt_flag) + + if train: + self.train() + + out, reduce_sim = self.feat( + x = image, + prompt = self.prompt, + cls_features = cls_features, + prompt_flag = self.prompt_flag + ) + + return out, reduce_sim + + elif self.prompt is not None: + with torch.no_grad(): + q, _ = self.feat(image) + q = q[:,0,:] + + # q?, train?, task_id? + out, prompt_loss = self.feat(image, prompt=self.prompt, q=q, train=train, task_id=self.task_id) + out = out[:,0,:] + else: + out, _ = self.feat(image, test, **kwargs) + if len(out.shape) == 3: + out = out[:,0,:] + + out = out.view(out.size(0), -1) + + if self.prompt is not None and train: + return out, prompt_loss + else: + return out + + def forward_proto(self, x, adapt_index): + return self.feat.forward_proto(x, adapt_index) + + def forward_general_cls(self, x, t_idx): + return self.feat.forward_general_cls(x, t_idx) + + def add_adapter_to_list(self): + self.feat.add_adapter_to_list() + +def vit_pt_imnet(pretrained=False, **kwargs): + return ViTZoo(pretrained, **kwargs) + +def vit_pt_imnet_in21k_adapter(pretrained=False, **kwargs): + return ViT_in21k_adapter(pretrained, **kwargs) + +def vit_cl_lora(pretrained=False, **kwargs): + return ViT_CL_LoRA(pretrained, **kwargs) \ No newline at end of file diff --git a/core/model/backbone/vit_dap.py b/core/model/backbone/vit_dap.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c967f04dc627384031bef21d0500db50f826c3 --- /dev/null +++ b/core/model/backbone/vit_dap.py @@ -0,0 +1,1170 @@ +""" Vision Transformer (ViT) in PyTorch +A PyTorch implement of Vision Transformers as described in: +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 +The official jax code is released and available at https://github.com/google-research/vision_transformer +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert +Hacked together by / Copyright 2020, Ross Wightman +# ------------------------------------------ +# Modification: +# Added code for l2p implementation +# -- Jaeho Lee, dlwogh9344@khu.ac.kr +# ------------------------------------------ +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.models import create_model +from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq +from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from timm.models.registry import register_model + +from .prompt import DAP + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (weights from official Google JAX impl) + 'vit_tiny_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_tiny_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_base_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + # 'vit_base_patch16_224': _cfg( + # url='https://storage.googleapis.com/vit_models/augreg/' + # 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + ), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + + 'vit_large_patch14_224': _cfg(url=''), + 'vit_huge_patch14_224': _cfg(url=''), + 'vit_giant_patch14_224': _cfg(url=''), + 'vit_gigantic_patch14_224': _cfg(url=''), + + + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_large_patch32_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), + 'vit_huge_patch14_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', + hf_hub_id='timm/vit_huge_patch14_224_in21k', + num_classes=21843), + + # SAM trained models (https://arxiv.org/abs/2106.01548) + 'vit_base_patch32_224_sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), + 'vit_base_patch16_224_sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), + + # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) + 'vit_small_patch16_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_small_patch8_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch16_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch8_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil_in21k': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', + mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + 'vit_base_patch16_224_miil': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' + '/vit_base_patch16_224_1k_miil_84_4.pth', + mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', + ), + + 'vit_base_patch16_rpn_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), + + # experimental (may be removed) + 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), + 'vit_small_patch16_36x1_224': _cfg(url=''), + 'vit_small_patch16_18x2_224': _cfg(url=''), + 'vit_base_patch16_18x2_224': _cfg(url=''), +} + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ResPostBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.init_values = init_values + + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.init_weights() + + def init_weights(self): + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x): + x = x + self.drop_path1(self.norm1(self.attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class ParallelBlock(nn.Module): + + def __init__( + self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.num_parallel = num_parallel + self.attns = nn.ModuleList() + self.ffns = nn.ModuleList() + for _ in range(num_parallel): + self.attns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + self.ffns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + + def _forward_jit(self, x): + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) + return x + + @torch.jit.ignore + def _forward(self, x): + x = x + sum(attn(x) for attn in self.attns) + x = x + sum(ffn(x) for ffn in self.ffns) + return x + + def forward(self, x): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return self._forward_jit(x) + else: + return self._forward(x) + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, + class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, + prompt_length=None, embedding_key='cls', prompt_init='uniform', prompt_pool=False, prompt_key=False, pool_size=None, + top_k=None, batchwise_prompt=False, prompt_key_init='uniform', head_type='token', use_prompt_mask=False,): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + block_fn: (nn.Module): transformer block + prompt_pool (bool): use prompt pool or not + """ + super().__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.img_size = img_size + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.class_token = class_token + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + self.grad_checkpointing = False + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + if prompt_length is not None and pool_size is not None and prompt_pool: + embed_len += prompt_length * top_k + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + self.prompt_pool = prompt_pool + self.head_type = head_type + self.use_prompt_mask = use_prompt_mask + + if prompt_length is not None and pool_size is not None and prompt_pool: + self.prompt = DAP(length=prompt_length, embed_dim=embed_dim, embedding_key=embedding_key, prompt_init=prompt_init, + prompt_pool=prompt_pool, prompt_key=prompt_key, pool_size=pool_size, top_k=top_k, batchwise_prompt=batchwise_prompt, + prompt_key_init=prompt_key_init,tasklength=10) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, task_id=-1, cls_features=None, train=False,gen=False): + x = self.patch_embed(x) + + if hasattr(self, 'prompt'): + if self.use_prompt_mask and train: + start = task_id * self.prompt.top_k + end = (task_id + 1) * self.prompt.top_k + single_prompt_mask = torch.arange(start, end).to(x.device) + prompt_mask = single_prompt_mask.unsqueeze(0).expand(x.shape[0], -1) + if end > self.prompt.pool_size: + prompt_mask = None + else: + prompt_mask = None + res = self.prompt(x, prompt_mask=prompt_mask, cls_features=cls_features,taskid=task_id) + self.total_prompt_len = res['total_prompt_len'] + if gen: + + x = res['gen_prompted_embedding'] + else: + x = res['prompted_embedding'] + + + else: + res=dict() + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + x = self.pos_drop(x + self.pos_embed) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + + x = self.norm(x) + res['x'] = x + + return res + + def forward_head(self, res, pre_logits: bool = False): + # x:[batchsize,token_len,embed_dim] + x = res['x'] + if self.class_token and self.head_type == 'token': + + x = x[:, 0] + elif self.head_type == 'gap' and self.global_pool == 'avg': + x = x.mean(dim=1) + elif self.head_type == 'prompt' and self.prompt_pool: + x = x[:, 1:(1 + self.total_prompt_len)] if self.class_token else x[:, 0:self.total_prompt_len] + x = x.mean(dim=1) + elif self.head_type == 'token+prompt' and self.prompt_pool and self.class_token: + x = x[:, 0:self.total_prompt_len + 1] + x = x.mean(dim=1) + else: + raise ValueError(f'Invalid classifier={self.classifier}') + + res['pre_logits'] = x + + x = self.fc_norm(x) + + res['logits'] = self.head(x) + + return res + + def forward(self, x, task_id=-1, cls_features=None, train=False, gen=False): + res = self.forward_features(x, task_id, cls_features=cls_features, train=train, gen=gen) + res = self.forward_head(res) + return res + + +def init_weights_vit_timm(module: nn.Module, name: str = ''): + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): + """ ViT weight initialization, matching JAX (Flax) impl """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_moco(module: nn.Module, name: str = ''): + """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + if isinstance(module, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_init_weights_vit(mode='jax', head_bias: float = 0.): + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + # modify + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] + # ntok_new -= num_prefix_tokens + else: + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if ntok_new > gs_old ** 2: + ntok_new -= gs_old ** 2 + # expand cls's pos embedding for prompt tokens + posemb_prefix = posemb_prefix.expand(-1, ntok_new, -1) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, + model.pos_embed, + 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + elif adapt_layer_scale and 'gamma_' in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) + model = build_model_with_cfg( + VisionTransformer, variant, pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in pretrained_cfg['url'], + **kwargs) + return model + + +@register_model +def vit_tiny_patch16_224_dap(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_384_dap(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) @ 384x384. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_224_dap(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_384_dap(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) at 384x384. + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_dap(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_384_dap(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_384_dap(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_384_dap(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch8_224_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224_dap(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_384_dap(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224_dap(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_384_dap(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch14_224_dap(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/14) + """ + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224_dap(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + """ + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_giant_patch14_224_dap(pretrained=False, **kwargs): + """ ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ + model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_gigantic_patch14_224_dap(pretrained=False, **kwargs): + """ ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ + model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch8_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224_in21k_dap(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights + """ + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_sam_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_sam_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_dino_dap(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch8_224_dino_dap(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_dino_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) /w DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch8_224_dino_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 + """ + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil_in21k_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) + return model + + +# Experimental models below + +@register_model +def vit_base_patch32_plus_256_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32+) + """ + model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_plus_240_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16+) + """ + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_rpn_224_dap(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ residual post-norm + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, + block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) + model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_36x1_224_dap(pretrained=False, **kwargs): + """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_18x2_224_dap(pretrained=False, **kwargs): + """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs) + model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_18x2_224_dap(pretrained=False, **kwargs): + """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) + model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + return model + + +def vit_pt_imnet_dap(pretrained=False, model_name='vit_base_patch16_224', **kwargs): + return create_model( + model_name, + pretrained=pretrained, + num_classes=kwargs['num_classes'], + drop_rate=kwargs['drop'], + drop_path_rate=kwargs['drop_path'], + drop_block_rate=None, + prompt_length=kwargs['length'], + embedding_key=kwargs['embedding_key'], + prompt_init=kwargs['prompt_key_init'], + prompt_pool=kwargs['prompt_pool'], + prompt_key=kwargs['prompt_key'], + pool_size=kwargs['size'], + top_k=kwargs['top_k'], + batchwise_prompt=kwargs['batchwise_prompt'], + prompt_key_init=kwargs['prompt_key_init'], + head_type=kwargs['head_type'], + use_prompt_mask=kwargs['use_prompt_mask'], + ) \ No newline at end of file diff --git a/core/model/backbone/vit_inflora.py b/core/model/backbone/vit_inflora.py new file mode 100644 index 0000000000000000000000000000000000000000..76d1c290e432f898ee9b40903da6d12dab6782a5 --- /dev/null +++ b/core/model/backbone/vit_inflora.py @@ -0,0 +1,720 @@ +""" Vision Transformer (ViT) in PyTorch +A PyTorch implement of Vision Transformers as described in: +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 +The official jax code is released and available at https://github.com/google-research/vision_transformer +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert +Hacked together by / Copyright 2020, Ross Wightman +""" + +import math +import logging +from functools import partial +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq +from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from timm.models.registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (weights from official Google JAX impl) + 'vit_tiny_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_tiny_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_base_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + ), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + + 'vit_large_patch14_224': _cfg(url=''), + 'vit_huge_patch14_224': _cfg(url=''), + 'vit_giant_patch14_224': _cfg(url=''), + 'vit_gigantic_patch14_224': _cfg(url=''), + + 'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_large_patch32_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), + 'vit_huge_patch14_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', + hf_hub_id='timm/vit_huge_patch14_224_in21k', + num_classes=21843), + + # SAM trained models (https://arxiv.org/abs/2106.01548) + 'vit_base_patch32_224_sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), + 'vit_base_patch16_224_sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), + + # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) + 'vit_small_patch16_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_small_patch8_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch16_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch8_224_dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil_in21k': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + 'vit_base_patch16_224_miil': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' + '/vit_base_patch16_224_1k_miil_84_4.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), + + # experimental + 'vit_small_patch16_36x1_224': _cfg(url=''), + 'vit_small_patch16_18x2_224': _cfg(url=''), + 'vit_base_patch16_18x2_224': _cfg(url=''), +} + + +class Attention_LoRA(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., r=64, n_tasks=10): + super().__init__() + + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + self.rank = r + + self.lora_A_k = nn.ModuleList([nn.Linear(dim, r, bias=False) for _ in range(n_tasks)]) + self.lora_B_k = nn.ModuleList([nn.Linear(r, dim, bias=False) for _ in range(n_tasks)]) + self.lora_A_v = nn.ModuleList([nn.Linear(dim, r, bias=False) for _ in range(n_tasks)]) + self.lora_B_v = nn.ModuleList([nn.Linear(r, dim, bias=False) for _ in range(n_tasks)]) + self.rank = r + + self.matrix = torch.zeros(dim ,dim) + self.n_matrix = 0 + self.cur_matrix = torch.zeros(dim ,dim) + self.n_cur_matrix = 0 + + def init_param(self): + for t in range(len(self.lora_A_k)): + nn.init.kaiming_uniform_(self.lora_A_k[t].weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A_v[t].weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B_k[t].weight) + nn.init.zeros_(self.lora_B_v[t].weight) + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, task, register_hook=False, get_feat=False,get_cur_feat=False): + if get_feat: + self.matrix = (self.matrix*self.n_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_matrix + x.shape[0]*x.shape[1]) + self.n_matrix += x.shape[0]*x.shape[1] + if get_cur_feat: + self.cur_matrix = (self.cur_matrix*self.n_cur_matrix + torch.bmm(x.detach().permute(0, 2, 1), x.detach()).sum(dim=0).cpu())/(self.n_cur_matrix + x.shape[0]*x.shape[1]) + self.n_cur_matrix += x.shape[0]*x.shape[1] + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + # insert lora + if task > -0.5: + weight_k = torch.stack([torch.mm(self.lora_B_k[t].weight, self.lora_A_k[t].weight) for t in range(task+1)], dim=0).sum(dim=0) + weight_v = torch.stack([torch.mm(self.lora_B_v[t].weight, self.lora_A_v[t].weight) for t in range(task+1)], dim=0).sum(dim=0) + k = k + F.linear(x, weight_k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = v + F.linear(x, weight_v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def get_matrix(self, task): + matrix_k = torch.mm(self.lora_B_k[task].weight, self.lora_A_k[task].weight) + matrix_v = torch.mm(self.lora_B_v[task].weight, self.lora_A_v[task].weight) + return matrix_k, matrix_v + + def get_pre_matrix(self, task): + with torch.no_grad(): + weight_k = torch.stack([torch.mm(self.lora_B_k[t].weight, self.lora_A_k[t].weight) for t in range(task)], dim=0).sum(dim=0) + weight_v = torch.stack([torch.mm(self.lora_B_v[t].weight, self.lora_A_v[t].weight) for t in range(task)], dim=0).sum(dim=0) + return weight_k, weight_v + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, n_tasks=10, r=64): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_LoRA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, n_tasks=n_tasks, r=r) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, task, register_hook=False, get_feat=False, get_cur_feat=False): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), task, register_hook=register_hook, get_feat=get_feat, get_cur_feat=get_cur_feat))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ParallelBlock(nn.Module): + + def __init__( + self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.num_parallel = num_parallel + self.attns = nn.ModuleList() + self.ffns = nn.ModuleList() + for _ in range(num_parallel): + self.attns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('attn', Attention_LoRA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + self.ffns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + + def _forward_jit(self, x): + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) + return x + + @torch.jit.ignore + def _forward(self, x): + x = x + sum(attn(x) for attn in self.attns) + x = x + sum(ffn(x) for ffn in self.ffns) + return x + + def forward(self, x): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return self._forward_jit(x) + else: + return self._forward(x) + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, n_tasks=10, rank=64): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init: (str): weight init scheme + init_values: (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool in ('', 'avg', 'token') + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.grad_checkpointing = False + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.cls_token_grow = nn.Parameter(torch.zeros(1, 5000, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_embed_grow = nn.Parameter(torch.zeros(1, num_patches + 1000, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,n_tasks=n_tasks,r=rank) + for i in range(depth)]) + use_fc_norm = self.global_pool == 'avg' + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Representation layer. Used for original ViT models w/ in21k pretraining. + self.representation_size = representation_size + self.pre_logits = nn.Identity() + if representation_size: + self._reset_representation(representation_size) + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + final_chs = self.representation_size if self.representation_size else self.embed_dim + self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + self.out_dim = final_chs + + if weight_init != 'skip': + self.init_weights(weight_init) + + def _reset_representation(self, representation_size): + self.representation_size = representation_size + if self.representation_size: + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(self.embed_dim, self.representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.pos_embed_grow, std=.02) + nn.init.normal_(self.cls_token, std=1e-6) + nn.init.normal_(self.cls_token_grow, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool + if representation_size is not None: + self._reset_representation(representation_size) + final_chs = self.representation_size if self.representation_size else self.embed_dim + self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + x = self.pos_drop(x + self.pos_embed) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_features_grow(self, x, class_num): + x = self.patch_embed(x) + # x = torch.cat((self.cls_token_grow[:, :class_num, :].expand(x.shape[0], -1, -1), self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + # x = self.pos_drop(x + self.pos_embed_grow[:, :self.patch_embed.num_patches+class_num, :]) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = torch.cat((self.cls_token_grow[:, :class_num*2, :].expand(x.shape[0], -1, -1), x), dim=1) + + # import pdb;pdb.set_trace() + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + x = self.pre_logits(x) + return x if pre_logits else self.head(x) + + def forward(self, x, grow_flag=False, numcls=0): + if not grow_flag: + x = self.forward_features(x) + else: + x = self.forward_features_grow(x, numcls) + + if self.global_pool: + x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + return { + 'fmaps': [x], + 'features': x + } + + +def init_weights_vit_timm(module: nn.Module, name: str = ''): + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): + """ ViT weight initialization, matching JAX (Flax) impl """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def init_weights_vit_moco(module: nn.Module, name: str = ''): + """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + if isinstance(module, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def get_init_weights_vit(mode='jax', head_bias: float = 0.): + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + # NOTE this extra code to support handling of repr size for in21k pretrained models + # pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) + pretrained_cfg = resolve_pretrained_cfg(variant) + default_num_classes = pretrained_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + if pretrained_cfg: + del kwargs['pretrained_cfg'] + + model = build_model_with_cfg( + VisionTransformer, variant, pretrained, + pretrained_cfg=pretrained_cfg, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in pretrained_cfg['url'], + **kwargs) + return model \ No newline at end of file diff --git a/core/model/bic.py b/core/model/bic.py new file mode 100644 index 0000000000000000000000000000000000000000..26701e27b9d2a937ced0d3a7490a35099e4770bb --- /dev/null +++ b/core/model/bic.py @@ -0,0 +1,897 @@ +""" +@inproceedings{wu2019large, + title={Large Scale Incremental Learning}, + author={Wu, Yue and Chen, Yinpeng and Wang, Lijuan and Ye, Yuancheng and Liu, Zicheng and Guo, Yandong and Fu, Yun}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={374--382}, + year={2019} +} +https://arxiv.org/abs/1905.13260 + +Adapted from https://github.com/wuyuebupt/LargeScaleIncrementalLearning and https://github.com/sairin1202/BIC. +""" + +import torch +import torch.optim as optim +import torch.nn.functional as F +import numpy as np + +from torch import nn +from copy import deepcopy +from torch.utils.data import DataLoader +from core.model.backbone.resnet import BiasLayer +from collections import Counter + +# Spilt images and labels into train_dataset and test_dataset, it is assured that each label's count are balanced in both output dataset +def classwise_spilt(images, labels, test_size, random_state=None): + + images, labels = np.array(images), np.array(labels) + classes = np.unique(labels) + + selected_train_images, selected_train_labels = [], [] + selected_val_images, selected_val_labels = [], [] + + for class_id in classes: + idx = np.where(labels == class_id)[0] + + if random_state is not None: + rng = np.random.RandomState(random_state + int(class_id)) # ensure different seed per class + rng.shuffle(idx) + else: + np.random.shuffle(idx) + + cls_images = images[idx] + cls_labels = labels[idx] + + split_idx = int(len(idx) * (1 - test_size)) + if split_idx == 0 and len(idx) > 1: + print(f"[WARNING] Class {class_id} has only {len(idx)} samples, forced 1 into val set.") + split_idx = 1 + + selected_train_images.extend(cls_images[:split_idx]) + selected_train_labels.extend(cls_labels[:split_idx]) + + selected_val_images.extend(cls_images[split_idx:]) + selected_val_labels.extend(cls_labels[split_idx:]) + + return selected_train_images, selected_val_images, selected_train_labels, selected_val_labels + +# Simply spilt the dataset by slicing, the balance of label is not assured +def slice_spilt(images, labels, test_size): + + total = len(labels) + + train_count = int(total*(1-test_size)) + test_count = int(total*test_size) + + selected_train_images, selected_train_labels = images[:train_count], labels[:train_count] + selected_val_images, selected_val_labels = images[train_count:], labels[train_count:] + + return selected_train_images, selected_val_images, selected_train_labels, selected_val_labels + +class Model(nn.Module): + + def __init__(self, backbone, num_class, device): + super().__init__() + self.backbone = backbone + self.num_class = num_class + self.classifier = nn.Linear(backbone.feat_dim, num_class) + + def forward(self, x): + return self.classifier(self.backbone(x)) + +class bic(nn.Module): + def __init__(self, backbone, num_class, **kwargs): + + super().__init__() + + self.device = kwargs['device'] + self.task_num = kwargs['task_num'] + self.bias_layers = nn.ModuleList([BiasLayer().to(self.device) for _ in range(self.task_num)]) + + params = [] + for layer in self.bias_layers: + params += layer.parameters() + + self.bias_optimizer = optim.Adam(params, lr = 1e-3) + self.model = Model(backbone, num_class, self.device) + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + + self.seen_cls = 0 + self.cur_task = 0 + + self.previous_model = None + self.criterion = nn.CrossEntropyLoss() + + self.cls_count = {} + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self.previous_model = deepcopy(self.model) + + for param in self.previous_model.parameters(): + param.requires_grad_(False) + + for param in self.model.parameters(): + param.requires_grad_(True) + + for layer in self.bias_layers: + [param.requires_grad_(False) for param in layer.parameters()] + + self.cur_task = task_idx + self.seen_cls += self.init_cls_num if task_idx == 0 else self.inc_cls_num + + def bias_forward(self, input, train=True): + + outputs = [] + + train = False + if train: + + for i, layer in enumerate(self.bias_layers): + if i == 0: + input_slice = input[:, :self.init_cls_num] + else: + input_slice = input[:, (i-1) * self.inc_cls_num + self.init_cls_num : i * self.inc_cls_num + self.init_cls_num] + + if i == self.cur_task: + outputs.append(layer(input_slice)) + else: + outputs.append(input_slice) + + else: + + for i, layer in enumerate(self.bias_layers): + if i == 0: + input_slice = input[:, :self.init_cls_num] + else: + input_slice = input[:, (i-1) * self.inc_cls_num + self.init_cls_num : i * self.inc_cls_num + self.init_cls_num] + + outputs.append(layer(input_slice)) + + return torch.cat(outputs, dim=1) + + def inference(self, data): + x, y = data['image'].to(self.device), data['label'].view(-1).to(self.device) + + p = self.model(x) + p = self.bias_forward(p, train = False) + pred = p[:, :self.seen_cls].argmax(dim=-1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0) + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + for param in self.model.parameters(): + param.requires_grad_(False) + + for i, layer in enumerate(self.bias_layers): + + if i == task_idx: + [param.requires_grad_(True) for param in layer.parameters()] + else: + [param.requires_grad_(False) for param in layer.parameters()] + + print(f'Bias Layer {i} : {layer.alpha.item()}, {layer.beta.item()} {layer.alpha.requires_grad}') + + # The classic two-phase processing approach employed by BIC. + def stage1(self, data): + + x, y = data['image'].to(self.device), data['label'].view(-1).to(self.device) + + p = self.model(x) + + p = self.bias_forward(p) + loss = self.criterion(p[:,:self.seen_cls], y) + pred = torch.argmax(p[:,:self.seen_cls], dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0), loss + + def stage1_distill(self, data): + + x, y = data['image'].to(self.device), data['label'].view(-1).to(self.device) + + T = 2 # temperature + alpha = 1.0 * (self.seen_cls - self.inc_cls_num) / self.seen_cls + assert 1.0 * self.cur_task / (self.cur_task + 1) == alpha + + p = self.model(x) + p = self.bias_forward(p) + + pred = torch.argmax(p[:, :self.seen_cls], dim=1) + acc = torch.sum(pred == y).item() + + with torch.no_grad(): + pre_p = self.previous_model(x) + pre_p = self.bias_forward(pre_p, train = True) + pre_p = F.softmax(pre_p[:, :self.seen_cls - self.inc_cls_num]/T, dim=1) + + logp = F.log_softmax(p[:, :self.seen_cls-self.inc_cls_num]/T, dim=1) + loss_soft_target = -torch.mean(torch.sum(pre_p * logp, dim=1)) + loss_hard_target = self.criterion(p[:, :self.seen_cls], y) + loss = alpha * loss_soft_target * T * T + (1-alpha) * loss_hard_target # T**2 stated in 'Distilling the Knowledge in a Neural Network', last paragraph of section 'Distillation' + + return pred, acc / x.size(0), loss + + def stage2(self, data): + + x, y = data['image'].to(self.device), data['label'].view(-1).to(self.device) + p = self.model(x) + p = self.bias_forward(p) + loss = self.criterion(p[:,:self.seen_cls], y) + pred = torch.argmax(p[:,:self.seen_cls], dim=1) + acc = torch.sum(pred == y).item() + + self.bias_optimizer.zero_grad() + loss.backward() + self.bias_optimizer.step() + + return pred, acc / x.size(0), loss + + def observe(self, data): + + if self.cur_task > 0: + return self.stage1_distill(data) + else: + return self.stage1(data) + + def get_parameters(self, config): + + return self.model.parameters() + + def spilt_and_update(self, dataloader, buffer, task_idx, config): + + val_ratio = 0.1 + buffer_size = config['buffer']['kwargs']['buffer_size'] + + train_dataset, val_dataset = deepcopy(dataloader.dataset), deepcopy(dataloader.dataset) + + # Save classes count for classwise buffering + self.cls_count.update(Counter(train_dataset.labels)) + + # Train_loader + images_train, images_val, labels_train, labels_val = classwise_spilt( + train_dataset.images, + train_dataset.labels, + test_size=val_ratio + ) + + train_dataset.images = images_train + buffer.train_images + train_dataset.labels = labels_train + buffer.train_labels + + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + batch_size=config['batch_size'], + num_workers=config['num_workers'], + drop_last=True) + + val_dataloader = None + + # Val_loader + if task_idx > 0: + + selected_images_val, selected_labels_val = buffer.val_images.copy(), buffer.val_labels.copy() + + for cls_label in np.unique(labels_val): + cls_idx, = np.where(np.array(labels_val) == cls_label) + cls_images, cls_labels = np.array(images_val)[cls_idx], np.array(labels_val)[cls_idx] + + selected_images_val.extend(cls_images) + selected_labels_val.extend(cls_labels) + + val_dataset.images = selected_images_val + val_dataset.labels = selected_labels_val + + val_dataloader = DataLoader( + val_dataset, + shuffle=True, + batch_size=100, + num_workers=config['num_workers'], + drop_last=False) + + # Update Buffer + buffer.train_images.extend(images_train) + buffer.train_labels.extend(labels_train) + buffer.val_images.extend(images_val) + buffer.val_labels.extend(labels_val) + buffer.total_classes += config['init_cls_num'] if task_idx == 0 else config['inc_cls_num'] + + preserved_images_train, preserved_labels_train = [], [] + preserved_images_val, preserved_labels_val = [], [] + + total_cls_counts = sum(self.cls_count.values()) + + for cls in range(buffer.total_classes): + train_cls_idx = np.where(np.array(buffer.train_labels) == cls) + train_cls_images = np.array(buffer.train_images)[train_cls_idx] + train_cls_labels = np.array(buffer.train_labels)[train_cls_idx] + + val_cls_idx = np.where(np.array(buffer.val_labels) == cls) + val_cls_images = np.array(buffer.val_images)[val_cls_idx] + val_cls_labels = np.array(buffer.val_labels)[val_cls_idx] + + preserved_val = int(self.cls_count[cls] * buffer_size / total_cls_counts * val_ratio) + preserved_train = int(self.cls_count[cls] * buffer_size / total_cls_counts * (1 - val_ratio)) + if preserved_val == 0 and preserved_train > 1: + preserved_val = 1 + preserved_train -= 1 + + print( + f"[Class {cls}] total: {self.cls_count[cls]} | " + f"buffer_train: {len(train_cls_labels)}, keep: {preserved_train} | " + f"buffer_val: {len(val_cls_labels)}, keep: {preserved_val}" + ) + + preserved_images_train.extend(train_cls_images[:preserved_train]) + preserved_labels_train.extend(train_cls_labels[:preserved_train]) + preserved_images_val.extend(val_cls_images[:preserved_val]) + preserved_labels_val.extend(val_cls_labels[:preserved_val]) + + buffer.train_images = preserved_images_train + buffer.train_labels = preserved_labels_train + buffer.val_images = preserved_images_val + buffer.val_labels = preserved_labels_val + print(f'Buffer Usage : {len(buffer.train_labels) + len(buffer.val_labels)}/{buffer_size}') + + return train_dataloader, val_dataloader + + ''' + # split_and_update1 (比例,且多): 将新的训练数据分成 9:1, 按照未更新的 buffer 中的 val data 的数量加入 val data + # split_and_update2 (比例,但少): 将新的训练数据分成 9:1, 按照更新后的 buffer 中的 val data 的数量加入 val data + # split_and_update4 (非比例,且少): 根据未更新的 buffer 中的 val data 的数量分割新的训练数据,然后安排 + + @staticmethod + def spilt_and_update1(dataloader, buffer, task_idx, config): + + print('using spilt_and_update1') + + train_dataset = deepcopy(dataloader.dataset) + val_dataset = deepcopy(dataloader.dataset) + + # Train_loader + images_train, images_val, labels_train, labels_val = slice_spilt( + train_dataset.images, + train_dataset.labels, + test_size=0.1 + ) + + train_dataset.images = images_train + buffer.train_images + train_dataset.labels = labels_train + buffer.train_labels + + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + batch_size=config['batch_size'], + num_workers=config['num_workers'], + drop_last=True) + + # Val_loader + if task_idx == 0: + val_dataloader = None + else: + current_num_per_classes_val = min(len(labels_val)//config['inc_cls_num'], (buffer.buffer_size * 0.1) // buffer.total_classes) + print(f'Assigning {current_num_per_classes_val} per class in val data') + + current_num_per_classes_val = int(current_num_per_classes_val) + + selected_images_val, selected_labels_val = [], [] + + for cls_label in np.unique(labels_val): + + cls_idx, = np.where(np.array(labels_val) == cls_label) + cls_images, cls_labels = np.array(images_val)[cls_idx], np.array(labels_val)[cls_idx] + + selected_images_val.extend(cls_images[:current_num_per_classes_val]) + selected_labels_val.extend(cls_labels[:current_num_per_classes_val]) + + print(f'{cls_label}, {len(cls_labels[:current_num_per_classes_val])}/{len(cls_labels)}') + + print(buffer.val_labels) + + for cls_label in range(buffer.total_classes): + cls_idx, = np.where(np.array(buffer.val_labels) == cls_label) + cls_images, cls_labels = np.array(buffer.val_images)[cls_idx], np.array(buffer.val_labels)[cls_idx] + + selected_images_val.extend(cls_images[:current_num_per_classes_val]) + selected_labels_val.extend(cls_labels[:current_num_per_classes_val]) + + print(f'{cls_label}, {len(cls_labels[:current_num_per_classes_val])}/{len(cls_labels)}') + + val_dataset.images = selected_images_val + val_dataset.labels = selected_labels_val + + val_dataloader = DataLoader( + val_dataset, + shuffle=True, + batch_size=100, + num_workers=config['num_workers'], + drop_last=False) + + # Update Buffer + buffer.total_classes += config['init_cls_num'] if task_idx == 0 else config['inc_cls_num'] + new_num_per_classes_train = int((config['buffer']['kwargs']['buffer_size'] // buffer.total_classes) * 0.9) + new_num_per_classes_val = int((config['buffer']['kwargs']['buffer_size'] // buffer.total_classes) * 0.1) + + preserved_images_train, preserved_labels_train = [], [] + preserved_images_val, preserved_labels_val = [], [] + + # Preserved old in buffer + for old_cls_label in np.unique(buffer.train_labels): + + cls_idx = np.where(buffer.train_labels == old_cls_label) + cls_images_train, cls_labels_train = np.array(buffer.train_images)[cls_idx], np.array(buffer.train_labels)[cls_idx] + + cls_idx = np.where(buffer.val_labels == old_cls_label) + cls_images_val, cls_labels_val = np.array(buffer.val_images)[cls_idx], np.array(buffer.val_labels)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + preserved_images_val.extend(cls_images_val[:new_num_per_classes_val]) + preserved_labels_val.extend(cls_labels_val[:new_num_per_classes_val]) + + # Add new into buffer + for new_cls_label in np.unique(labels_train): + + cls_idx = np.where(labels_train == new_cls_label) + cls_images_train, cls_labels_train = np.array(images_train)[cls_idx], np.array(labels_train)[cls_idx] + + cls_idx = np.where(labels_val == new_cls_label) + cls_images_val, cls_labels_val = np.array(images_val)[cls_idx], np.array(labels_val)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + preserved_images_val.extend(cls_images_val[:new_num_per_classes_val]) + preserved_labels_val.extend(cls_labels_val[:new_num_per_classes_val]) + + buffer.train_images = preserved_images_train + buffer.train_labels = preserved_labels_train + buffer.val_images = preserved_images_val + buffer.val_labels = preserved_labels_val + + return train_dataloader, val_dataloader + + @staticmethod + def spilt_and_update2(dataloader, buffer, task_idx, config): + + print('using spilt_and_update2') + + train_dataset = deepcopy(dataloader.dataset) + val_dataset = deepcopy(dataloader.dataset) + + # Train_loader + images_train, images_val, labels_train, labels_val = slice_spilt( + train_dataset.images, + train_dataset.labels, + test_size=0.1 + ) + + train_dataset.images = images_train + buffer.train_images + train_dataset.labels = labels_train + buffer.train_labels + + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + batch_size=config['batch_size'], + num_workers=config['num_workers'], + drop_last=True) + + # Update Buffer + buffer.total_classes += config['init_cls_num'] if task_idx == 0 else config['inc_cls_num'] + new_num_per_classes_train = int((config['buffer']['kwargs']['buffer_size'] // buffer.total_classes) * 0.9) + new_num_per_classes_val = int((config['buffer']['kwargs']['buffer_size'] // buffer.total_classes) * 0.1) + + preserved_images_train, preserved_labels_train = [], [] + preserved_images_val, preserved_labels_val = [], [] + + # Preserved old in buffer + for old_cls_label in np.unique(buffer.train_labels): + + cls_idx = np.where(buffer.train_labels == old_cls_label) + cls_images_train, cls_labels_train = np.array(buffer.train_images)[cls_idx], np.array(buffer.train_labels)[cls_idx] + + cls_idx = np.where(buffer.val_labels == old_cls_label) + cls_images_val, cls_labels_val = np.array(buffer.val_images)[cls_idx], np.array(buffer.val_labels)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + preserved_images_val.extend(cls_images_val[:new_num_per_classes_val]) + preserved_labels_val.extend(cls_labels_val[:new_num_per_classes_val]) + + print(f'{old_cls_label}, {len(cls_labels_val[:new_num_per_classes_val])}/{len(cls_labels_val)}') + + # Add new into buffer + for new_cls_label in np.unique(labels_train): + + cls_idx = np.where(labels_train == new_cls_label) + cls_images_train, cls_labels_train = np.array(images_train)[cls_idx], np.array(labels_train)[cls_idx] + + cls_idx = np.where(labels_val == new_cls_label) + cls_images_val, cls_labels_val = np.array(images_val)[cls_idx], np.array(labels_val)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + preserved_images_val.extend(cls_images_val[:new_num_per_classes_val]) + preserved_labels_val.extend(cls_labels_val[:new_num_per_classes_val]) + + print(f'{new_cls_label}, {len(cls_labels_val[:new_num_per_classes_val])}/{len(cls_labels_val)}') + + buffer.train_images = preserved_images_train + buffer.train_labels = preserved_labels_train + buffer.val_images = preserved_images_val + buffer.val_labels = preserved_labels_val + + # Val loader + if task_idx == 0: + val_dataloader = None + else: + print(f'Assigning {new_num_per_classes_val} per class in val data') + + val_dataset.images = buffer.val_images + val_dataset.labels = buffer.val_labels + + val_dataloader = DataLoader( + val_dataset, + shuffle=True, + batch_size=100, + num_workers=config['num_workers'], + drop_last=False) + + return train_dataloader, val_dataloader + + def spilt_and_update11(dataloader, buffer, task_idx, config): + + print('using spilt_and_update11') + + train_dataset = deepcopy(dataloader.dataset) + val_dataset = deepcopy(dataloader.dataset) + + # Train_loader + images_train, images_val, labels_train, labels_val = slice_spilt( + train_dataset.images, + train_dataset.labels, + test_size=0.1 + ) + + train_dataset.images = train_dataset.images + buffer.train_images + train_dataset.labels = train_dataset.labels + buffer.train_labels + + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + batch_size=config['batch_size'], + num_workers=config['num_workers'], + drop_last=True) + + # Val_loader + if task_idx == 0: + val_dataloader = None + else: + current_num_per_classes_val = min(len(labels_val)//config['inc_cls_num'], (buffer.buffer_size * 0.1) // buffer.total_classes) + print(f'Assigning {current_num_per_classes_val} per class in val data') + + current_num_per_classes_val = int(current_num_per_classes_val) + + selected_images_val, selected_labels_val = [], [] + + for cls_label in np.unique(labels_val): + + cls_idx, = np.where(np.array(labels_val) == cls_label) + cls_images, cls_labels = np.array(images_val)[cls_idx], np.array(labels_val)[cls_idx] + + selected_images_val.extend(cls_images[:current_num_per_classes_val]) + selected_labels_val.extend(cls_labels[:current_num_per_classes_val]) + + print(f'{cls_label}, {len(cls_labels[:current_num_per_classes_val])}/{len(cls_labels)}') + + print(buffer.val_labels) + + for cls_label in range(buffer.total_classes): + cls_idx, = np.where(np.array(buffer.val_labels) == cls_label) + cls_images, cls_labels = np.array(buffer.val_images)[cls_idx], np.array(buffer.val_labels)[cls_idx] + + selected_images_val.extend(cls_images[:current_num_per_classes_val]) + selected_labels_val.extend(cls_labels[:current_num_per_classes_val]) + + print(f'{cls_label}, {len(cls_labels[:current_num_per_classes_val])}/{len(cls_labels)}') + + val_dataset.images = selected_images_val + val_dataset.labels = selected_labels_val + + val_dataloader = DataLoader( + val_dataset, + shuffle=True, + batch_size=100, + num_workers=config['num_workers'], + drop_last=False) + + # Update Buffer + buffer.total_classes += config['init_cls_num'] if task_idx == 0 else config['inc_cls_num'] + new_num_per_classes_train = int((config['buffer']['kwargs']['buffer_size'] // buffer.total_classes) * 0.9) + new_num_per_classes_val = int((config['buffer']['kwargs']['buffer_size'] // buffer.total_classes) * 0.1) + + preserved_images_train, preserved_labels_train = [], [] + preserved_images_val, preserved_labels_val = [], [] + + # Preserved old in buffer + for old_cls_label in np.unique(buffer.train_labels): + + cls_idx = np.where(buffer.train_labels == old_cls_label) + cls_images_train, cls_labels_train = np.array(buffer.train_images)[cls_idx], np.array(buffer.train_labels)[cls_idx] + + cls_idx = np.where(buffer.val_labels == old_cls_label) + cls_images_val, cls_labels_val = np.array(buffer.val_images)[cls_idx], np.array(buffer.val_labels)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + preserved_images_val.extend(cls_images_val[:new_num_per_classes_val]) + preserved_labels_val.extend(cls_labels_val[:new_num_per_classes_val]) + + # Add new into buffer + for new_cls_label in np.unique(labels_train): + + cls_idx = np.where(labels_train == new_cls_label) + cls_images_train, cls_labels_train = np.array(images_train)[cls_idx], np.array(labels_train)[cls_idx] + + cls_idx = np.where(labels_val == new_cls_label) + cls_images_val, cls_labels_val = np.array(images_val)[cls_idx], np.array(labels_val)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + preserved_images_val.extend(cls_images_val[:new_num_per_classes_val]) + preserved_labels_val.extend(cls_labels_val[:new_num_per_classes_val]) + + buffer.train_images = preserved_images_train + buffer.train_labels = preserved_labels_train + buffer.val_images = preserved_images_val + buffer.val_labels = preserved_labels_val + + return train_dataloader, val_dataloader + + @staticmethod + def spilt_and_update22(dataloader, buffer, task_idx, config): + + print('using spilt_and_update22') + + train_dataset = deepcopy(dataloader.dataset) + val_dataset = deepcopy(dataloader.dataset) + + # Train_loader + images_train, images_val, labels_train, labels_val = slice_spilt( + train_dataset.images, + train_dataset.labels, + test_size=0.1 + ) + + #train_dataset.images = buffer.train_images + train_dataset.images + #train_dataset.labels = buffer.train_labels + train_dataset.labels + + train_dataset.images = train_dataset.images + buffer.train_images + train_dataset.labels = train_dataset.labels + buffer.train_labels + + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + batch_size=config['batch_size'], + num_workers=config['num_workers'], + drop_last=True) + + # Update Buffer + buffer.total_classes += config['init_cls_num'] if task_idx == 0 else config['inc_cls_num'] + new_num_per_classes_train = int((config['buffer']['kwargs']['buffer_size'] // buffer.total_classes) * 0.9) + new_num_per_classes_val = int((config['buffer']['kwargs']['buffer_size'] // buffer.total_classes) * 0.1) + + preserved_images_train, preserved_labels_train = [], [] + preserved_images_val, preserved_labels_val = [], [] + + # Preserved old in buffer + for old_cls_label in np.unique(buffer.train_labels): + + cls_idx = np.where(buffer.train_labels == old_cls_label) + cls_images_train, cls_labels_train = np.array(buffer.train_images)[cls_idx], np.array(buffer.train_labels)[cls_idx] + + cls_idx = np.where(buffer.val_labels == old_cls_label) + cls_images_val, cls_labels_val = np.array(buffer.val_images)[cls_idx], np.array(buffer.val_labels)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + preserved_images_val.extend(cls_images_val[:new_num_per_classes_val]) + preserved_labels_val.extend(cls_labels_val[:new_num_per_classes_val]) + + print(f'{old_cls_label}, {len(cls_labels_val[:new_num_per_classes_val])}/{len(cls_labels_val)}') + + # Add new into buffer + for new_cls_label in np.unique(labels_train): + + cls_idx = np.where(labels_train == new_cls_label) + cls_images_train, cls_labels_train = np.array(images_train)[cls_idx], np.array(labels_train)[cls_idx] + + cls_idx = np.where(labels_val == new_cls_label) + cls_images_val, cls_labels_val = np.array(images_val)[cls_idx], np.array(labels_val)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + preserved_images_val.extend(cls_images_val[:new_num_per_classes_val]) + preserved_labels_val.extend(cls_labels_val[:new_num_per_classes_val]) + + print(f'{new_cls_label}, {len(cls_labels_val[:new_num_per_classes_val])}/{len(cls_labels_val)}') + + buffer.train_images = preserved_images_train + buffer.train_labels = preserved_labels_train + buffer.val_images = preserved_images_val + buffer.val_labels = preserved_labels_val + + # Val loader + if task_idx == 0: + val_dataloader = None + else: + print(f'Assigning {new_num_per_classes_val} per class in val data') + + val_dataset.images = buffer.val_images + val_dataset.labels = buffer.val_labels + + val_dataloader = DataLoader( + val_dataset, + shuffle=True, + batch_size=100, + num_workers=config['num_workers'], + drop_last=False) + + return train_dataloader, val_dataloader + + @staticmethod + def spilt_and_update4(dataloader, buffer, task_idx, config): + + print('using spilt_and_update4') + + buffer_size = config['buffer']['kwargs']['buffer_size'] + + train_dataset = deepcopy(dataloader.dataset) + val_dataset = deepcopy(dataloader.dataset) + + current_images = train_dataset.images + current_labels = train_dataset.labels + + if task_idx == 0: + + buffer.total_classes += config['init_cls_num'] + new_num_per_classes_train = int(buffer_size * 0.9) // buffer.total_classes + new_num_per_classes_val = int(buffer_size * 0.1) // buffer.total_classes + + ratio = (buffer_size * 0.1) / len(current_labels) + + images_train, images_val, labels_train, labels_val = balance_spilt( + current_images, + current_labels, + test_size=ratio, + random_state=config['seed'] + ) + + # Some Assertions + value_counts = Counter(labels_train) + count1 = next(iter(value_counts.values())) + for value, count in value_counts.items(): + assert count == count1 + + value_counts = Counter(labels_val) + count1 = next(iter(value_counts.values())) + for value, count in value_counts.items(): + assert count == count1 + + + + + + print(ratio, len(labels_train), len(labels_val)) + + train_dataset.images = images_train + train_dataset.labels = labels_train + + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + batch_size=config['batch_size'], + num_workers=config['num_workers'], + drop_last=True) + + buffer.val_images = images_val + buffer.val_labels = labels_val + + preserved_images_train, preserved_labels_train = [], [] + for cls_label in range(config['init_cls_num']): + + cls_idx = np.where(labels_train == cls_label) + cls_images_train, cls_labels_train = np.array(images_train)[cls_idx], np.array(labels_train)[cls_idx] + + preserved_images_train.extend(cls_images_train[:new_num_per_classes_train]) + preserved_labels_train.extend(cls_labels_train[:new_num_per_classes_train]) + + buffer.train_images = preserved_images_train + buffer.train_labels = preserved_labels_train + + val_dataloader = None + + else: + + buffer.total_classes += config['inc_cls_num'] + new_num_per_classes_train = int(buffer_size * 0.9) // buffer.total_classes + new_num_per_classes_val = int(buffer_size * 0.1) // buffer.total_classes + + ratio = new_num_per_classes_val * config['inc_cls_num'] / len(current_labels) + + images_train, images_val, labels_train, labels_val = balance_spilt( + current_images, + current_labels, + test_size=ratio, + random_state=config['seed'] + ) + + print(ratio, len(labels_train), len(labels_val)) + + train_dataset.images = images_train + buffer.train_images + train_dataset.labels = labels_train + buffer.train_labels + + train_dataloader = DataLoader( + train_dataset, + shuffle=True, + batch_size=config['batch_size'], + num_workers=config['num_workers'], + drop_last=True) + + buffer.train_images.extend(images_train) + buffer.train_labels.extend(labels_train) + + buffer.val_images.extend(images_val) + buffer.val_labels.extend(labels_val) + + preserved_train_images, preserved_train_labels = [], [] + preserved_val_images, preserved_val_labels = [], [] + for cls_label in range(buffer.total_classes): + + cls_idx = np.where(np.array(buffer.train_labels) == cls_label) + cls_train_images, cls_train_labels = np.array(buffer.train_images)[cls_idx], np.array(buffer.train_labels)[cls_idx] + + cls_idx = np.where(np.array(buffer.val_labels) == cls_label) + cls_val_images, cls_val_labels = np.array(buffer.val_images)[cls_idx], np.array(buffer.val_labels)[cls_idx] + + preserved_train_images.extend(cls_train_images[:new_num_per_classes_train]) + preserved_train_labels.extend(cls_train_labels[:new_num_per_classes_train]) + + preserved_val_images.extend(cls_val_images[:new_num_per_classes_val]) + preserved_val_labels.extend(cls_val_labels[:new_num_per_classes_val]) + + + print(f'{cls_label}, {len(cls_val_labels[:new_num_per_classes_val])}/{len(cls_val_labels)}') + + buffer.train_images = preserved_train_images + buffer.train_labels = preserved_train_labels + buffer.val_images = preserved_val_images + buffer.val_labels = preserved_val_labels + + print(f'Assigning {new_num_per_classes_val} per class in val data') + + val_dataset.images = buffer.val_images + val_dataset.labels = buffer.val_labels + + assert len(buffer.val_labels) == new_num_per_classes_val * buffer.total_classes + + val_dataloader = DataLoader( + val_dataset, + shuffle=True, + batch_size=100, + num_workers=config['num_workers'], + drop_last=False) + + return train_dataloader, val_dataloader + ''' \ No newline at end of file diff --git a/core/model/buffer/__init__.py b/core/model/buffer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa40fb5282e141482feb24b21455a45073c83df4 --- /dev/null +++ b/core/model/buffer/__init__.py @@ -0,0 +1,5 @@ +from .linearbuffer import * +from .update import * +from .linearherdingbuffer import * +from .onlinebuffer import * +from .erbuffer import * \ No newline at end of file diff --git a/core/model/buffer/erbuffer.py b/core/model/buffer/erbuffer.py new file mode 100644 index 0000000000000000000000000000000000000000..9447f788a1f897f28b7b5d3f9b09df3608a82ee6 --- /dev/null +++ b/core/model/buffer/erbuffer.py @@ -0,0 +1,350 @@ +import numpy as np +import torch +import torch.nn as nn + +from collections import OrderedDict +from collections.abc import Iterable + +class ERBuffer(nn.Module): + def __init__(self, capacity): + super().__init__() + + # create placeholders for each item + self.buffers = [] + + self.cap = capacity + self.buffer_size = capacity + self.current_index = 0 + self.n_seen_so_far = 0 + self.is_full = 0 + + # defaults + self.add = self.add_reservoir + self.sample = self.sample_random + + def __len__(self): + return self.current_index + + def add_buffer(self, name, dtype, size): + """ used to add extra containers (e.g. for logit storage) """ + + tmp = torch.zeros(size=(self.cap,) + size, dtype=dtype).to(self.device) + self.register_buffer(f'b{name}', tmp) + self.buffers += [f'b{name}'] + + def _init_buffers(self, batch): + created = 0 + + for name, tensor in batch.items(): + bname = f'b{name}' + if bname not in self.buffers: + + if not type(tensor) == torch.Tensor: + tensor = torch.from_numpy(np.array([tensor])) + + self.add_buffer(name, tensor.dtype, tensor.shape[1:]) + created += 1 + + print(f'created buffer {name}\t {tensor.dtype}, {tensor.shape[1:]}') + + assert created in [0, len(batch)], 'not all buffers created at the same time' + + def add_reservoir(self, batch): + + self._init_buffers(batch) + + n_elem = batch['x'].shape[0] + + place_left = max(0, self.cap - self.current_index) + + indices = torch.FloatTensor(n_elem).to(self.device) + indices = indices.uniform_(0, self.n_seen_so_far).long() + + if place_left > 0: + upper_bound = min(place_left, n_elem) + indices[:upper_bound] = torch.arange(upper_bound) + self.current_index + + valid_indices = (indices < self.cap).long() + idx_new_data = valid_indices.nonzero().squeeze(-1) + idx_buffer = indices[idx_new_data] + + self.n_seen_so_far += n_elem + self.current_index = min(self.n_seen_so_far, self.cap) + + if idx_buffer.numel() == 0: + return + + # perform overwrite op + for name, data in batch.items(): + buffer = getattr(self, f'b{name}') + + if isinstance(data, Iterable): + buffer[idx_buffer] = data[idx_new_data] + else: + buffer[idx_buffer] = data + + def add_balanced(self, batch): + self._init_buffers(batch) + + n_elem = batch['x'].size(0) + + # increment first + self.n_seen_so_far += n_elem + self.current_index = min(self.n_seen_so_far, self.cap) + + # first thing is we just add all the data + for name, data in batch.items(): + buffer = getattr(self, f'b{name}') + + if not isinstance(data, Iterable): + data = buffer.new(size=(n_elem, *buffer.shape[1:])).fill_(data) + + buffer = torch.cat((data, buffer))[:self.n_seen_so_far] + setattr(self, f'b{name}', buffer) + + n_samples_over = buffer.size(0) - self.cap + + # no samples to remove + if n_samples_over <= 0: + return + + # remove samples from the most common classes + class_count = self.by.bincount() + rem_per_class = torch.zeros_like(class_count) + + while rem_per_class.sum() < n_samples_over: + max_idx = class_count.argmax() + rem_per_class[max_idx] += 1 + class_count[max_idx] -= 1 + + # always remove the oldest samples for each class + classes_trimmed = rem_per_class.nonzero().flatten() + idx_remove = [] + + for cls in classes_trimmed: + cls_idx = (self.by == cls).nonzero().view(-1) + idx_remove += [cls_idx[-rem_per_class[cls]:]] + + idx_remove = torch.cat(idx_remove) + idx_mask = torch.BoolTensor(buffer.size(0)).to(self.device) + idx_mask.fill_(0) + idx_mask[idx_remove] = 1 + + # perform overwrite op + for name, data in batch.items(): + buffer = getattr(self, f'b{name}') + buffer = buffer[~idx_mask] + setattr(self, f'b{name}', buffer) + + def add_queue(self, batch): + self._init_buffers(batch) + + if not hasattr(self, 'queue_ptr'): + self.queue_ptr = 0 + + start_idx = self.queue_ptr + end_idx = (start_idx + batch['x'].size(0)) % self.cap + + for name, data in batch.items(): + buffer = getattr(self, f'b{name}') + buffer[start_idx:end_idx] = data + + def sample_random(self, amt, exclude_task=None, **kwargs): + buffers = OrderedDict() + + if exclude_task is not None: + assert hasattr(self, 'bt') + valid_indices = torch.where(self.bt != exclude_task)[0] + valid_indices = valid_indices[valid_indices < self.current_index] + for buffer_name in self.buffers: + buffers[buffer_name[1:]] = getattr(self, buffer_name)[valid_indices] + else: + for buffer_name in self.buffers: + buffers[buffer_name[1:]] = getattr(self, buffer_name)[:self.current_index] + + n_selected = buffers['x'].size(0) + if n_selected <= amt: + assert n_selected > 0 + return buffers + else: + idx_np = np.random.choice(buffers['x'].size(0), amt, replace=False) + indices = torch.from_numpy(idx_np).to(self.bx.device) + + return OrderedDict({k:v[indices] for (k,v) in buffers.items()}) + + def sample_balanced(self, amt, exclude_task=None, **kwargs): + buffers = OrderedDict() + + if exclude_task is not None: + assert hasattr(self, 'bt') + valid_indices = (self.bt != exclude_task).nonzero().squeeze() + for buffer_name in self.buffers: + buffers[buffer_name[1:]] = getattr(self, buffer_name)[valid_indices] + else: + for buffer_name in self.buffers: + buffers[buffer_name[1:]] = getattr(self, buffer_name)[:self.current_index] + + class_count = buffers['y'].bincount() + + # a sample's prob. of being sample is inv. prop to its class abundance + class_sample_p = 1. / class_count.float() / class_count.size(0) + per_sample_p = class_sample_p.gather(0, buffers['y']) + indices = torch.multinomial(per_sample_p, amt) + + return OrderedDict({k:v[indices] for (k,v) in buffers.items()}) + + def sample_pos_neg(self, inc_data, task_free=True, same_task_neg=True): + + x = inc_data['x'] + label = inc_data['y'] + task = torch.zeros_like(label).fill_(inc_data['t']) + + # we need to create an "augmented" buffer containing the incoming data + bx = torch.cat((self.bx[:self.current_index], x)) + by = torch.cat((self.by[:self.current_index], label)) + bt = torch.cat((self.bt[:self.current_index], task)) + bidx = torch.arange(bx.size(0)).to(bx.device) + + # buf_size x label_size + same_label = label.view(1, -1) == by.view(-1, 1) + same_task = task.view(1, -1) == bt.view(-1, 1) + same_ex = bidx[-x.size(0):].view(1, -1) == bidx.view(-1, 1) + + task_labels = label.unique() + real_same_task = same_task + + if task_free: + same_task = torch.zeros_like(same_task) + + for label_ in task_labels: + label_exp = label_.view(1, -1).expand_as(same_task) + same_task = same_task | (label_exp == by.view(-1, 1)) + + valid_pos = same_label & ~same_ex + + if same_task_neg: + valid_neg = ~same_label & same_task + else: + valid_neg = ~same_label + + # remove points which don't have pos, neg from same and diff t + has_valid_pos = valid_pos.sum(0) > 0 + has_valid_neg = valid_neg.sum(0) > 0 + + invalid_idx = ~has_valid_pos | ~has_valid_neg + + if invalid_idx.sum() > 0: + # so the fetching operation won't fail + valid_pos[:, invalid_idx] = 1 + valid_neg[:, invalid_idx] = 1 + + # easier if invalid_idx is a binary tensor + is_invalid = torch.zeros_like(label).bool() + is_invalid[invalid_idx] = 1 + + # fetch positive samples + pos_idx = torch.multinomial(valid_pos.float().T, 1).squeeze(1) + neg_idx = torch.multinomial(valid_neg.float().T, 1).squeeze(1) + + n_fwd = torch.stack((bidx[-x.size(0):], pos_idx, neg_idx), 1)[~invalid_idx].unique().size(0) + + return bx[pos_idx], \ + bx[neg_idx], \ + by[pos_idx], \ + by[neg_idx], \ + is_invalid, \ + n_fwd + + def sample_minimal_pos_neg(self, inc_data, task_free=True, same_task_neg=True): + """ maximize choosing the incoming data to minimize forward passes """ + + x = inc_data['x'] + label = inc_data['y'] + task = torch.zeros_like(label).fill_(inc_data['t']) + + ''' + # we need to create an "augmented" buffer containing the incoming data + bx = torch.cat((self.bx[:self.current_index], x)) + by = torch.cat((self.by[:self.current_index], label)) + bt = torch.cat((self.bt[:self.current_index], task)) + bidx = torch.arange(bx.size(0)).to(bx.device) + + # buf_size x label_size + same_label = label.view(1, -1) == by.view(-1, 1) + same_task = task.view(1, -1) == bt.view(-1, 1) + same_ex = bidx[-x.size(0):].view(1, -1) == bidx.view(-1, 1) + ''' + + bidx = torch.arange(x.size(0)).to(x.device) + + # label_size x label_size + same_label = label.view(1, -1) == label.view(-1, 1) + same_task = task.view(1, -1) == task.view(-1, 1) + same_ex = bidx.view(1, -1) == bidx.view(-1, 1) + + task_labels = label.unique() + real_same_task = same_task + + # TASK FREE METHOD : instead of using the task ID, we'll use labels in + # the current batch to mimic task + if task_free: + same_task = torch.zeros_like(same_task) + + for label_ in task_labels: + label_exp = label_.view(1, -1).expand_as(same_task) + same_task = same_task | (label_exp == label.view(-1, 1)) + + valid_pos = same_label & ~same_ex + + if same_task_neg: + valid_neg = ~same_label & same_task + else: + valid_neg = ~same_label + + # remove points which don't have pos, neg from same and diff t + has_valid_pos = valid_pos.sum(0) > 0 + has_valid_neg = valid_neg.sum(0) > 0 + + invalid_idx = ~has_valid_pos | ~has_valid_neg + + if invalid_idx.any(): + # so the fetching operation won't fail + valid_pos[:, invalid_idx] = 1 + valid_neg[:, invalid_idx] = 1 + + # easier if invalid_idx is a binary tensor + is_invalid = torch.zeros_like(label).bool() + is_invalid[invalid_idx] = 1 + + # fetch positive samples + pos_idx = torch.multinomial(valid_pos.float().T, 1).squeeze(1) + neg_idx = torch.multinomial(valid_neg.float().T, 1).squeeze(1) + + # return + pos_x, neg_x = x[pos_idx], x[neg_idx] + pos_y, neg_y = label[pos_idx], label[neg_idx] + + n_fwd = torch.stack((bidx, pos_idx, neg_idx), 1)[~invalid_idx].unique().size(0) + + # --- handle cases that can be solved by looking into the buffer: + if invalid_idx.any(): + # build new input + invalid_data = OrderedDict() + invalid_data['x'] = x[invalid_idx] + invalid_data['y'] = label[invalid_idx] + invalid_data['t'] = inc_data['t'] + + n_pos_x, n_neg_x, n_pos_y, n_neg_y, n_is_invalid, n_new_fwd = \ + self.sample_pos_neg(invalid_data, task_free=task_free, same_task_neg=same_task_neg) + + # next we fill the invalid indices with their potentially valid points from the buffer + pos_x[invalid_idx][~n_is_invalid].data.copy_(n_pos_x[~n_is_invalid]) + neg_x[invalid_idx][~n_is_invalid].data.copy_(n_neg_x[~n_is_invalid]) + pos_y[invalid_idx][~n_is_invalid].data.copy_(n_pos_y[~n_is_invalid]) + neg_y[invalid_idx][~n_is_invalid].data.copy_(n_neg_y[~n_is_invalid]) + + invalid_idx[invalid_idx].data.copy_(n_is_invalid) + + n_fwd += n_new_fwd + + return pos_x, neg_x, pos_y, neg_y, is_invalid, n_fwd diff --git a/core/model/buffer/linearbuffer.py b/core/model/buffer/linearbuffer.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7cdff91603e479f4d4e28c8791948b516e0740 --- /dev/null +++ b/core/model/buffer/linearbuffer.py @@ -0,0 +1,28 @@ +import numpy as np + + +class LinearBuffer: + def __init__(self, buffer_size, strategy, batch_size): + + self.buffer_size = buffer_size + self.strategy = strategy + self.batch_size = batch_size + self.total_classes = 0 + self.images, self.labels = [], [] + + def is_empty(self): + return len(self.labels) == 0 + +class LinearSpiltBuffer: + def __init__(self, buffer_size, strategy, batch_size, val_ratio): + + self.buffer_size = buffer_size + self.strategy = strategy + self.batch_size = batch_size + self.val_ratio = 0.1 + self.total_classes = 0 + self.train_images, self.train_labels = [], [] + self.val_images, self.val_labels = [], [] + + def is_empty(self): + return len(self.train_labels) == 0 \ No newline at end of file diff --git a/core/model/buffer/linearherdingbuffer.py b/core/model/buffer/linearherdingbuffer.py new file mode 100644 index 0000000000000000000000000000000000000000..3197824b7943529bd7433be4ad920fafc29b3d80 --- /dev/null +++ b/core/model/buffer/linearherdingbuffer.py @@ -0,0 +1,165 @@ +import numpy as np +import torch +import torch.nn as nn +import PIL +import os +from typing import List +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + +class LinearHerdingBuffer: + def __init__(self, buffer_size, batch_size): + self.buffer_size = buffer_size + self.strategy = None + self.batch_size = batch_size + self.images, self.labels = [], [] + self.total_classes = 0 + + def is_empty(self): + return len(self.labels) == 0 + + def clear(self): + # clear the buffer + del self.images + del self.labels + self.images = [] + self.labels = [] + + def get_all_data(self): + # return images and labels in the format of np.array + return np.array(self.images), np.array(self.labels) + + def add_data(self, data:List[str], targets:List[str]): + # add data and its labels to the buffer + self.images.extend(data) + self.labels.extend(targets) + + + def update(self, model:nn.Module, train_loader, val_transform, task_idx:int, + total_cls_num:int, cur_cls_indexes, device): + + # get the chosen global index in the dataset for buffer + chosen_indexes = self.herding_select(model, train_loader, val_transform, + task_idx, total_cls_num, cur_cls_indexes, + device) + + cur_task_dataset = train_loader.dataset + new_images = [] + new_labels = [] + for i in chosen_indexes: + new_images.append(cur_task_dataset.images[i]) + new_labels.append(cur_task_dataset.labels[i]) + + self.add_data(new_images, new_labels) + + def reduce_old_data(self, task_idx:int, total_cls_num:int) -> None: + # subsample previous categories in the buffer + samples_per_class = self.buffer_size // total_cls_num + + if samples_per_class == 0: + print( + f"Warning: Buffer size ({self.buffer_size}) is too small for total classes ({total_cls_num}). ", + f"Samples per class will be set to 1, to avoid empty buffer." + ) + samples_per_class = 1 + + if task_idx > 0: + buffer_X, buffer_Y = self.get_all_data() + self.clear() + for y in np.unique(buffer_Y): + idx = (buffer_Y == y) + selected_X, selected_Y = buffer_X[idx], buffer_Y[idx] + self.add_data( + data=selected_X[:samples_per_class], + targets=selected_Y[:samples_per_class], + ) + + + def herding_select(self, model:nn.Module, train_loader, val_transform, + task_idx:int, total_cls_num:int, cur_cls_indexes, device): + + # Remove buffer samples from the dataset + # and keep only the samples belonging to the current task category. + def remove_buffer_sample_in_dataset(dataset, cur_cls_indexes): + new_labels = [] + new_images = [] + for i in cur_cls_indexes: + ind = np.array(dataset.labels) == i + new_images.extend(list(np.array(dataset.images)[ind])) + new_labels.extend(list(np.array(dataset.labels)[ind])) + dataset.labels = new_labels + dataset.images = new_images + + # get dataset containing buffer samples + dataset = train_loader.dataset + + # remove buffer samples and only keep + remove_buffer_sample_in_dataset(dataset, cur_cls_indexes) + + # reset the transform + dataset.trfms = val_transform + + # get loader for herding + loader = DataLoader( + dataset, + # Note that `shuffle = False` should be set. + # otherwise otherwise the generated indexes will not match with the paths of the images + shuffle = False, + batch_size = 32, + # `drop_last = False` should be set as False, otherwise some samples are lost + drop_last = False + ) + + # how many sample per class do we want + samples_per_class = self.buffer_size // total_cls_num + if samples_per_class == 0: + print( + f"Warning: Buffer size ({self.buffer_size}) is too small for total classes ({total_cls_num}). ", + f"Samples per class will be set to 1, to avoid empty buffer." + ) + samples_per_class = 1 + + + # compute feature for all training sample for all train samples + extracted_features = [] + extracted_targets = [] + # print("!!!!! The origin code is\'feats = model.backbone(image)['features'] \', change to \'feats = model.extract_vector(image) \' by WA") + with torch.no_grad(): + model.eval() + for data in loader: + image = data['image'].to(device) + label = data['label'].to(device) + # feats = model.extract_vector(image) + feats = model.backbone(image)['features'] + feats = feats / feats.norm(dim=1).view(-1, 1) # Feature normalization + extracted_features.append(feats) + extracted_targets.append(label) + extracted_features = (torch.cat(extracted_features)).cpu() + extracted_targets = (torch.cat(extracted_targets)).cpu() + + result = [] + for curr_cls in np.unique(extracted_targets): + + cls_ind = np.where(extracted_targets == curr_cls)[0] + cls_feats = extracted_features[cls_ind] + mean_feat = cls_feats.mean(0, keepdim=True) + running_sum = torch.zeros_like(mean_feat) + i = 0 + begin_index = cls_ind[0] + while i < samples_per_class and i < cls_feats.shape[0]: + cost = (mean_feat - (cls_feats + running_sum) / (i + 1)).norm(2, 1) + + # Notice that the initial offset should be added + # since indexes we want are global in the dataset + # hence we should guarantee indexes belonging to the same class + # should be continuous + idx_min = cost.argmin().item() + global_index = idx_min + begin_index + result.append(global_index) + running_sum += cls_feats[idx_min:idx_min + 1] + cls_feats[idx_min] = cls_feats[idx_min] + 1e6 + i += 1 + + return result + + diff --git a/core/model/buffer/onlinebuffer.py b/core/model/buffer/onlinebuffer.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5919a363b7e78ec513ba77e78d8e175a2397b6 --- /dev/null +++ b/core/model/buffer/onlinebuffer.py @@ -0,0 +1,120 @@ +import numpy as np +import pdb +import torch +import torch.nn as nn +import torch.nn.functional as F + +# modified from https://github.com/gydpku/OCM/blob/main/buffer.py + +class OnlineBuffer(nn.Module): + def __init__(self, buffer_size, batch_size, input_size): + super().__init__() + + self.place_left = True + self.strategy = None + self.buffer_size = buffer_size + print('buffer has %d slots' % buffer_size, buffer_size) + + buf_data = torch.FloatTensor(buffer_size, *input_size).fill_(0) + buf_targets = torch.LongTensor(buffer_size).fill_(0) + buf_tasks = torch.LongTensor(buffer_size).fill_(0) + + self.current_index = 0 + self.n_seen_so_far = 0 + self.is_full = 0 + self.total_classes = 0 + # registering as buffer allows us to save the object using `torch.save` + self.register_buffer('buf_data', buf_data) + self.register_buffer('buf_targets', buf_targets) + self.register_buffer('buf_tasks', buf_tasks) + + + def tensor_to_device(self, device): + self.device = device + self.buf_data.to(device) + self.buf_targets.to(device) + self.buf_tasks.to(device) + + + + def add_reservoir(self, x, y, task): + n_elem = x.size(0) + + self.device = x.device + place_left = max(0, self.buffer_size - self.current_index) + offset = min(place_left, n_elem) + + if place_left: + offset = min(place_left, n_elem) + + self.buf_data[self.current_index: self.current_index + offset].data.copy_(x[:offset]) + self.buf_targets[self.current_index: self.current_index + offset].data.copy_(y[:offset]) + self.buf_tasks[self.current_index: self.current_index + offset].fill_(task) + self.current_index += offset + self.n_seen_so_far += offset + + if offset == x.size(0): + return + + self.place_left = False + + # remove what is already in the buffer + x, y = x[place_left:], y[place_left:] + + indices = torch.FloatTensor(x.size(0)).to(x.device).uniform_(0, self.n_seen_so_far).long() + valid_indices = (indices < self.buf_data.size(0)).long() + + idx_new_data = valid_indices.nonzero().squeeze(-1) + idx_buffer = indices[idx_new_data] + + self.n_seen_so_far += x.size(0) + + if idx_buffer.numel() == 0: + return + + assert idx_buffer.max() < self.buf_data.size(0), pdb.set_trace() + assert idx_buffer.max() < self.buf_targets.size(0), pdb.set_trace() + assert idx_buffer.max() < self.buf_tasks.size(0), pdb.set_trace() + + assert idx_new_data.max() < x.size(0), pdb.set_trace() + assert idx_new_data.max() < y.size(0), pdb.set_trace() + + if self.buf_data.device != x.device: + self.buf_data = self.buf_data.to(x.device) + self.buf_targets = self.buf_targets.to(x.device) + self.buf_tasks = self.buf_tasks.to(x.device) + + self.buf_data[idx_buffer] = x[idx_new_data] + self.buf_targets[idx_buffer] = y[idx_new_data] + self.buf_tasks[idx_buffer] = task + + + + + def sample(self, amount, exclude_task = None, ret_ind = False): + + if self.buf_data.device != self.device: + self.buf_data = self.buf_data.to(self.device) + self.buf_targets = self.buf_targets.to(self.device) + self.buf_tasks = self.buf_tasks.to(self.device) + + if exclude_task is not None: + valid_indices = (self.t != exclude_task) + valid_indices = valid_indices.nonzero().squeeze() + bx, by, bt = self.buf_data[valid_indices], self.buf_targets[valid_indices], self.buf_tasks[valid_indices] + else: + bx, by, bt = self.buf_data[:self.current_index], self.buf_targets[:self.current_index], self.buf_tasks[:self.current_index] + + if bx.size(0) < amount: + if ret_ind: + return bx, by, bt, torch.from_numpy(np.arange(bx.size(0))) + else: + return bx, by, bt + else: + indices = torch.from_numpy(np.random.choice(bx.size(0), amount, replace=False)) + indices = indices.to(self.device) + + if ret_ind: + return bx[indices], by[indices], bt[indices], indices + else: + return bx[indices], by[indices], bt[indices] \ No newline at end of file diff --git a/core/model/buffer/update.py b/core/model/buffer/update.py new file mode 100644 index 0000000000000000000000000000000000000000..a73e58faf9537f36ece0188c44f4eeeff588a48d --- /dev/null +++ b/core/model/buffer/update.py @@ -0,0 +1,85 @@ +import numpy as np +import torch +import copy +from collections import Counter +from torch.utils.data import DataLoader + +def random_update(datasets, buffer): + + images = np.array(datasets.images + buffer.images) + labels = np.array(datasets.labels + buffer.labels) + perm = np.random.permutation(len(labels)) + + images, labels = images[perm[:buffer.buffer_size]], labels[perm[:buffer.buffer_size]] + + buffer.images = images.tolist() + buffer.labels = labels.tolist() + +def herding_update(datasets, buffer, feature_extractor, device): + + print("Using Herding Update Strategy") + + per_classes = buffer.buffer_size // buffer.total_classes + + selected_images, selected_labels = [], [] + images = np.array(datasets.images + buffer.images) + labels = np.array(datasets.labels + buffer.labels) + + for cls in range(buffer.total_classes): + cls_images_idx = np.where(labels == cls) + cls_images, cls_labels = images[cls_images_idx], labels[cls_images_idx] + + cls_selected_images, cls_selected_labels = construct_examplar(copy.copy(datasets), cls_images, cls_labels, feature_extractor, per_classes, device) + selected_images.extend(cls_selected_images) + selected_labels.extend(cls_selected_labels) + + label_counter = Counter(buffer.labels) + print("\nBuffer composition per class:") + for cls in sorted(label_counter.keys()): + print(f" Class {cls:3d} : {label_counter[cls]} samples") + + buffer.images, buffer.labels = selected_images, selected_labels + +def construct_examplar(datasets, images, labels, feature_extractor, per_classes, device): + if len(images) <= per_classes: + print(labels[0], len(images), per_classes) + return images, labels + + datasets.images, datasets.labels = images, labels + dataloader = DataLoader(datasets, shuffle = False, batch_size = 32, drop_last = False) + + with torch.no_grad(): + features = [] + for data in dataloader: + imgs = data['image'].to(device) + features.append(feature_extractor(imgs)['features'].cpu().numpy().tolist()) + + features = np.concatenate(features) + selected_images, selected_labels = [], [] + selected_features = [] + class_mean = np.mean(features, axis=0) + + for k in range(1, per_classes+1): + if len(selected_features) == 0: + S = np.zeros_like(features[0]) + else: + S = np.mean(np.array(selected_features), axis=0) + + + mu_p = (S + features) / k + i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) + + selected_images.append(images[i]) + selected_labels.append(labels[i]) + selected_features.append(features[i]) + + features = np.delete(features, i, axis=0) + images = np.delete(images, i) + labels = np.delete(labels, i) + + return selected_images, selected_labels + + + + + diff --git a/core/model/cl_lora.py b/core/model/cl_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..3919efa97f49ed50e8a140f64e030fcfd08c3839 --- /dev/null +++ b/core/model/cl_lora.py @@ -0,0 +1,344 @@ +""" +@article{He_2025_CVPR, + author = {He, Jiangpeng and Duan, Zhihao and Zhu, Fengqing}, + title = {CL-LoRA: Continual Low-Rank Adaptation for Rehearsal-Free Class-Incremental Learning}, + journal = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)}, + month = {June}, + year = {2025}, + pages = {30534-30544} +} + +Adapted from https://github.com/JiangpengHe/CL-LoRA +""" + +import math +import torch + +import numpy as np +import torch.nn as nn + +from tqdm import tqdm +from torch import optim +from copy import deepcopy +from torch.nn import functional as F + +from .backbone.transformer import MultiHeadAttention_CL_LoRA + +def _KD_loss(pred, soft, T): + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] + +def compute_orthogonality_loss(previous_weights_list, current_weights, epsilon=1e-8): + total_ortho_loss = 0.0 + current_norm = torch.norm(current_weights.flatten()) + current_normalized = current_weights.flatten() / (current_norm + epsilon) + + for prev_weights in previous_weights_list: + # Normalize previous weights + prev_norm = torch.norm(prev_weights.flatten()) + prev_normalized = prev_weights.flatten() / (prev_norm + epsilon) + + # Compute absolute dot product (should be close to 0 for orthogonal vectors) + dot_product = torch.abs(torch.sum(prev_normalized * current_normalized)) + + total_ortho_loss += dot_product + + # Average over all previous tasks + if len(previous_weights_list) > 0: + total_ortho_loss /= len(previous_weights_list) + + return total_ortho_loss + +class CosineLinearFeature(nn.Module): + def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True): + super(CosineLinearFeature, self).__init__() + self.in_features = in_features + self.out_features = out_features * nb_proxy + self.nb_proxy = nb_proxy + self.to_reduce = to_reduce + self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features)) + if sigma: + self.sigma = nn.Parameter(torch.Tensor(1)) + else: + self.register_parameter('sigma', None) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.sigma is not None: + self.sigma.data.fill_(1) + + def reset_parameters_to_zero(self): + self.weight.data.fill_(0) + + def forward(self, input): + out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) + + if self.to_reduce: + # Reduce_proxy + out = reduce_proxies(out, self.nb_proxy) + + if self.sigma is not None: + out = self.sigma * out + + return {'logits': out} + + def forward_diagonal(self, input, cur_task, alpha=0., beta=0.0, init_cls=10, inc=10, out_dim=768, use_init_ptm=False): + for i in range(cur_task + 1): + if i == 0: + start_cls = 0 + end_cls = init_cls + else: + start_cls = init_cls + (i - 1) * inc + end_cls = start_cls + inc + input1 = F.normalize(input[:, i * out_dim:(i + 1) * out_dim], p=2, dim=1) + weight1 = F.normalize(self.weight[start_cls:end_cls, i * out_dim:(i + 1) * out_dim], p=2, dim=1) + + out = F.linear(input1, weight1) + if i == 0: + out_all = out + else: + out_all = torch.cat((out_all, out), dim=1) if i != 0 else out + + if self.to_reduce: + # Reduce_proxy + out_all = reduce_proxies(out_all, self.nb_proxy) + + if self.sigma is not None: + out_all = self.sigma * out_all + + return {'logits': out_all} + +class Model(nn.Module): + def __init__(self, backbone, device, **kwargs): + super().__init__() + self.backbone = backbone + self.inc = kwargs["inc_cls_num"] + self.init_cls = kwargs["init_cls_num"] + self._cur_task = -1 + self.out_dim = 768 + self.fc = None + self.alpha = 0. + self.beta = 0 + self.fc_list = nn.ModuleList() + self.fc_list_task = nn.ModuleList() + self.adapter_list = nn.ModuleList() + self.init_proto = None + + self._device = device + + def freeze(self): + for name, param in self.named_parameters(): + param.requires_grad = False + + @property + def feature_dim(self): + + return self.out_dim * (self._cur_task + 1) + + def update_fc(self, nb_classes): + self._cur_task += 1 + + if self._cur_task == 0: + self.proxy_fc = self.generate_fc(self.out_dim, self.init_cls).to(self._device) + else: + self.proxy_fc = self.generate_fc(self.out_dim, self.inc).to(self._device) + init_proto = self.generate_fc(self.out_dim, nb_classes).to(self._device) + + if self.init_proto is not None: + old_nb_classes = self.init_proto.out_features + weight = deepcopy(self.init_proto.weight.data) + init_proto.weight.data[: old_nb_classes, :] = nn.Parameter(weight) + del self.init_proto + self.init_proto = init_proto + + fc = self.generate_fc(self.feature_dim, nb_classes).to(self._device) + fc.reset_parameters_to_zero() + + if self.fc is not None: + old_nb_classes = self.fc.out_features + weight = deepcopy(self.fc.weight.data) + fc.sigma.data = self.fc.sigma.data + fc.weight.data[: old_nb_classes, : -self.out_dim] = nn.Parameter(weight) + + del self.fc + self.fc = fc + self.fc.requires_grad_(False) + + def add_fc(self): + self.fc_list.append(self.proxy_fc.requires_grad_(False)) + del self.proxy_fc + + def generate_fc(self, in_dim, out_dim): + fc = CosineLinearFeature(in_dim, out_dim) + return fc + + def forward_kd(self, x, t_idx): + x_new, x_teacher = self.backbone.forward_general_cls(x, t_idx) + out_new, out_teacher = self.proxy_fc(x_new), self.proxy_fc(x_teacher) + return out_new, out_teacher + + def forward(self, x, test=False): + if test == False: + x = self.backbone.forward(x, test=False) + out = self.proxy_fc(x) + out.update({"features": x}) + return out + else: + + x_input = self.backbone.forward(x, test=True) + out = self.fc.forward_diagonal(x_input, cur_task=self._cur_task, alpha=0., init_cls=self.init_cls, inc=self.inc, use_init_ptm=False, beta=0) + out.update({"features": x_input}) + + return out + +class CL_LoRA(nn.Module): + + def __init__(self, backbone, device, **kwargs): + + super().__init__() + + self.device = device + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.task_num = kwargs["task_num"] + self._known_classes = 0 + self._total_classes = 0 + self._cur_task = 0 + + self._network = Model(backbone, device, **kwargs) + self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_CL_LoRA)] + + self.lora_modules = [[] for _ in range(self.task_num)] + self.lora_scale_weights = [[] for _ in range(self.task_num)] + self.optim = None + + def observe(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + aux_targets = y - self._known_classes + + logits = self._network(x, test=False)['logits'] + loss = F.cross_entropy(logits, aux_targets) + + if self._cur_task > 0: + + kd_ratio = 5. + Temperature = 2 + + out_new, out_teacher = self._network.forward_kd(x, self._cur_task) + out_new_logits = out_new["logits"] + out_teacher_logits = out_teacher["logits"] + loss_kd = kd_ratio * _KD_loss(out_new_logits, out_teacher_logits, T=Temperature) + + self.optim.zero_grad() + loss_kd.backward() + + for j in range(len(self._network.backbone.feat.general_pos)): + pos = self._network.backbone.feat.adapt_pos.index(self._network.backbone.feat.general_pos[j]) + for jj in range(len(self._network.backbone.feat.msa)): + if self._network.backbone.feat.msa[jj] == 1: + temp_weights = 1. * torch.norm(self._network.backbone.feat.old_adapter_list[self._cur_task-1][pos][jj].lora_A.weight,dim=1) + temp_weights = 1. * len(temp_weights) * temp_weights / torch.sum(temp_weights) + self._network.backbone.feat.cur_adapter[pos][jj].lora_A.weight.grad = temp_weights.unsqueeze(1) * self._network.backbone.feat.cur_adapter[pos][jj].lora_A.weight.grad + + self.optim.step() + + orth_loss_specific = compute_orthogonality_loss(self._network.backbone.feat.block_weight_list, self._network.backbone.feat.block_weight) + loss += 0.0001 * orth_loss_specific + + preds = logits.max(1)[1] + correct_count = preds.eq(aux_targets).sum().item() + acc = correct_count / y.size(0) + + return preds, acc, loss + + def inference(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + logits = self._network(x, True)["logits"] + preds = logits.max(1)[1] + + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc + + @torch.no_grad() + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + if task_idx > 0: + self._known_classes = self._total_classes + self._network.freeze() + self._network.backbone.add_adapter_to_list() + + self._cur_task = task_idx + self._total_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + self._network.update_fc(self._total_classes) + + for name, param in self._network.named_parameters(): + if 'backbone.feat.cur_adapter' in name or 'proxy_fc.' in name or 'init_proto' in name: + param.requires_grad_(True) + else: + param.requires_grad_(False) + + param.requires_grad_(False) + + if 'lora' in name and 'cur_adapter' in name: + if any(f'er.{i}.' in name for i in range(6)) and 'lora_B' in name and 'cur_adapter': + pass + else: + param.requires_grad_(True) + + elif f'proxy_fc' in name: + param.requires_grad_(True) + elif 'init_proto' in name: + param.requires_grad_(True) + elif 'block_weight' in name and 'old' not in name: + param.requires_grad_(True) + + self._network = self._network.to(self.device) + + @torch.no_grad() + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + self._network.add_fc() + train_loader.dataset.trfms = test_loaders[0].dataset.trfms + self.replace_fc(train_loader) + + self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + + def replace_fc(self, train_loader): + model = self._network + model = model.eval() + + with torch.no_grad(): + for index in range(0, self._cur_task + 1): + embedding_list, label_list = [], [] + for i, batch in enumerate(train_loader): + data, label = batch['image'], batch['label'] + data = data.to(self.device) + label = label.to(self.device) + embedding = model.backbone.forward_proto(data, adapt_index=index) + embedding_list.append(embedding.cpu()) + label_list.append(label.cpu()) + + embedding_list = torch.cat(embedding_list, dim=0) + label_list = torch.cat(label_list, dim=0) + + class_list = np.unique(train_loader.dataset.labels) + for class_index in class_list: + data_index = (label_list == class_index).nonzero().squeeze(-1) + embedding = embedding_list[data_index] + proto = embedding.mean(0) + model.fc.weight.data[class_index, index*self._network.out_dim:(index+1)*self._network.out_dim] = proto + + def get_parameters(self, config): + return self._network.parameters() + + def set_optim(self, optim): + self.optim = optim \ No newline at end of file diff --git a/core/model/codaprompt.py b/core/model/codaprompt.py new file mode 100644 index 0000000000000000000000000000000000000000..63d885e569745ccbfc4d4361e1e23851c31f025f --- /dev/null +++ b/core/model/codaprompt.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{DBLP:conf/cvpr/SmithKGCKAPFK23, + author = {James Seale Smith and + Leonid Karlinsky and + Vyshnavi Gutta and + Paola Cascante{-}Bonilla and + Donghyun Kim and + Assaf Arbelle and + Rameswar Panda and + Rog{\'{e}}rio Feris and + Zsolt Kira}, + title = {CODA-Prompt: COntinual Decomposed Attention-Based Prompting for Rehearsal-Free + Continual Learning}, + booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, + {CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023}, + pages = {11909--11919}, + publisher = {{IEEE}}, + year = {2023} +} + +https://arxiv.org/abs/2211.13218 + +Adapted from https://github.com/GT-RIPL/CODA-Prompt +""" + +import math +import copy +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +from .finetune import Finetune +from core.model.backbone.resnet import * +import numpy as np +from torch.utils.data import DataLoader + + +class Model(nn.Module): + # A model consists with a backbone and a classifier + def __init__(self, backbone, feat_dim, num_class): + super().__init__() + self.backbone = backbone + self.feat_dim = feat_dim + self.num_class = num_class + self.classifier = nn.Linear(feat_dim, num_class) + + def forward(self, x, train=True): + if train: + feat, loss = self.backbone(x, train=True) + return self.classifier(feat), loss + else: + feat = self.backbone(x, train=False) + return self.classifier(feat) + + +class CodaPrompt(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.kwargs = kwargs + self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num']) + self.network.backbone.create_prompt('coda', n_tasks = kwargs['task_num'], prompt_param=[kwargs['pool_size'], kwargs['prompt_length'], kwargs['mu']]) + self.task_idx = 0 + self.kwargs = kwargs + + self.last_out_dim = 0 + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + self.network.backbone.task_id = task_idx + + in_features = self.network.classifier.in_features + out_features = self.network.classifier.out_features + new_out_features = self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num'] + new_fc = nn.Linear(in_features, new_out_features) + new_fc.weight.data[:out_features] = self.network.classifier.weight.data + new_fc.bias.data[:out_features] = self.network.classifier.bias.data + self.network.classifier = new_fc + self.network.to(self.device) + + self.loss_fn = nn.CrossEntropyLoss(reduction='none') + + self.out_dim = new_out_features + self.dw_k = torch.tensor(np.ones(self.out_dim + 1, dtype=np.float32)).to(self.device) + + def observe(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logit, loss = self.network(x, train=True) + + logit[:,:self.last_out_dim] = -float('inf') + dw_cls = self.dw_k[-1 * torch.ones(y.size()).long()] + + loss += (self.loss_fn(logit, y) * dw_cls).mean() + + pred = torch.argmax(logit, dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0), loss + + + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + self.last_out_dim = self.out_dim + + def inference(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + logit = self.network(x, train=False) + + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + + def get_parameters(self, config): + return list(self.network.backbone.prompt.parameters()) + list(self.network.classifier.parameters()) diff --git a/core/model/dap.py b/core/model/dap.py new file mode 100644 index 0000000000000000000000000000000000000000..a53d2552cd99851fa883b09c3e734d2a9bbe3070 --- /dev/null +++ b/core/model/dap.py @@ -0,0 +1,200 @@ +""" +@inproceedings{10.24963/ijcai.2024/456, + author = {Hong, Chenxing and Jin, Yan and Kang, Zhiqi and Chen, Yizhou and Li, Mengke and Lu, Yang and Wang, Hanzi}, + title = {Dynamically anchored prompting for task-imbalanced continual learning}, + booktitle = {Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence}, + year = {2025}, +} +https://dl.acm.org/doi/10.24963/ijcai.2024/456 +Adapted from https://github.com/chenxing6666/dap +""" + +import math +import copy +import torch +import torch.nn.functional as F +from .finetune import Finetune +import numpy as np +from torch.utils.data import DataLoader + +global_max_dist = torch.tensor(0) +global_max_dist2 = torch.tensor(0) +global_lam = 0.25 + + +class DAP(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.kwargs = kwargs + self.network = backbone + self.train_mask = kwargs['train_mask'] + self.task_inc = kwargs['task_inc'] + self.pull_constraint = kwargs['pull_constraint'] + self.pull_constraint_coeff = kwargs['pull_constraint_coeff'] + + self.task_idx = 0 + self.task_data_count = [] + self.prompt_center = None + + # initialize class_mask + if self.num_class % kwargs['task_num'] != 0: + raise ValueError('Number of classes must be divisible by number of tasks') + classes_per_task = self.num_class // kwargs['task_num'] + self.class_mask = [list(range(i * classes_per_task, (i + 1) * classes_per_task)) for i in range(kwargs['task_num'])] + + self.original_model = copy.deepcopy(self.backbone) + self.original_model.to(self.device) + self.original_model.eval() + + if kwargs['freeze']: + # all parameters are frozen for original vit model + for p in self.original_model.parameters(): + p.requires_grad = False + + # freeze args.freeze[blocks, patch_embed, cls_token] parameters + for n, p in self.network.named_parameters(): + if n.startswith(tuple(kwargs['freeze'])): + p.requires_grad = False + + self.loss_fn.to(self.device) + + def observe(self, data, train_gprompt=False, gen=False): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + with torch.no_grad(): + if self.original_model is not None: + output = self.original_model(x) + cls_features = output['pre_logits'] + else: + cls_features = None + if gen: + output = self.network(x, task_id=self.task_idx, cls_features=cls_features, train=True, gen=gen) + else: + output = self.network(x, task_id=self.task_idx, cls_features=cls_features, train=True) + logits = output['logits'] + + # here is the trick to mask out classes of non-current tasks + if self.train_mask and self.class_mask is not None: + mask = self.class_mask[self.task_idx] + not_mask = np.setdiff1d(np.arange(self.num_class), mask) + not_mask = torch.tensor(not_mask, dtype=torch.int64).to(self.device) + logits = logits.index_fill( + dim=1, index=not_mask, value=float('-inf')) + + if (train_gprompt): + + pla_similarity_loss_res = self.cal_latestsimilarity_loss( + model=self.network, task_id=self.task_idx) + sta_similarity_loss_res = self.cal_similarity_loss(model=self.network, task_id=self.task_idx, prompt_center=self.prompt_center) + + pla_similarity_loss = pla_similarity_loss_res['similarity'] + sta_similarity_loss = sta_similarity_loss_res['avg_similarity'] + + min_data_count = min(self.task_data_count) + max_data_count = max(self.task_data_count) + last_data_count = self.task_data_count[-1] + epsilon = 1e-10 + alpha = (last_data_count - min_data_count) / (max_data_count - min_data_count + epsilon) + + loss2 = alpha*sta_similarity_loss + loss3 = (1-alpha)*pla_similarity_loss + + loss = self.loss_fn(logits, y) + loss2 + loss3 + + else: + # base criterion (CrossEntropyLoss) + loss = self.loss_fn(logits, y) + if self.pull_constraint and 'reduce_sim' in output: + loss = loss - self.pull_constraint_coeff * output['reduce_sim'] + + if not math.isfinite(loss.item()): + raise RuntimeError(f'Loss is {loss.item()}, stopping training') + + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0), loss + + def inference(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + with torch.no_grad(): + if self.original_model is not None: + output = self.original_model(x) + cls_features = output['pre_logits'] + else: + cls_features = None + output = self.network(x, task_id=self.task_idx, cls_features=cls_features, gen=True) + logits = output['logits'] + + # adding mask to output logits + if self.task_inc and self.class_mask is not None: + mask = self.class_mask[self.task_idx] + mask = torch.tensor(mask, dtype=torch.int64).to(self.device) + logits_mask = torch.ones_like(logits, device=self.device) * float('-inf') + logits_mask = logits_mask.index_fill(1, mask, 0.0) + logits = logits + logits_mask + + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0) + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + self.network.task_id = task_idx + self.task_data_count.append(len(train_loader.dataset)) + + @staticmethod + def cal_latestsimilarity_loss(model: torch.nn.Module, task_id=-1): + res = dict() + global global_max_dist2 + + gprompt = model.prompt.generalprompt + tprompt = model.prompt.taskprompt[task_id].detach() + + gprompt_flat = gprompt.view(-1) + tprompt_tensors = tprompt.view(-1) + similarity = 1-F.cosine_similarity(gprompt_flat, tprompt_tensors, dim=0) + res['similarity'] = similarity + return res + + @staticmethod + def cal_center(model: torch.nn.Module, task_id=-1, task_data_count=None, prompt_center=None): + tprompt = model.prompt.taskprompt + if task_id > 0: + if prompt_center is None: + prompt_center = tprompt[0].detach().view(-1) + current_tprompt = tprompt[task_id - 1].detach().view(-1) + if task_data_count: + weights = [1 / count for count in task_data_count[:task_id]] + normalized_weight = weights[-1] / sum(weights) + weights2 = sum(weights[:-1]) / sum(weights) + else: + normalized_weight = 1.0 / task_id + prompt_center = (prompt_center * weights2) + \ + (current_tprompt * normalized_weight) + else: + prompt_center = torch.zeros_like(tprompt[0].detach().view(-1)) + return prompt_center + + @staticmethod + def cal_similarity_loss(model: torch.nn.Module, task_id=-1, prompt_center=None): + res = dict() + global global_max_dist + + gprompt = model.prompt.generalprompt + + if task_id > 0: + gprompt_flat = gprompt.view(-1) + similarity = 1-F.cosine_similarity(gprompt_flat, prompt_center, dim=0) + res['similarity'] = similarity + res['avg_similarity'] = similarity + else: + res['similarity'] = torch.tensor(0) + res['avg_similarity'] = 0 + return res \ No newline at end of file diff --git a/core/model/der.py b/core/model/der.py new file mode 100644 index 0000000000000000000000000000000000000000..c6348caa4e3127430903bfb30b328bddc9e6fde7 --- /dev/null +++ b/core/model/der.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{DBLP:conf/cvpr/YanX021, + author = {Shipeng Yan and + Jiangwei Xie and + Xuming He}, + title = {{DER:} Dynamically Expandable Representation for Class Incremental + Learning}, + booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR} + 2021, virtual, June 19-25, 2021}, + pages = {3014--3023}, + year = {2021}, +} + +https://openaccess.thecvf.com/content/CVPR2021/papers/Yan_DER_Dynamically_Expandable_Representation_for_Class_Incremental_Learning_CVPR_2021_paper.pdf + +Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/der.py +""" +import math +import copy +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +from .finetune import Finetune +from core.model.backbone import resnet18, resnet34, resnet50 +from core.utils import get_instance + +def get_convnet(convnet_type, pretrained=False): + name = convnet_type.lower() + if name == "resnet18": + dic = {"num_classes": 10, "args":{'dataset':'cifar100'}} + return resnet18(**dic) + # elif name=="resnet32": + # return resnet32() + elif name == "resnet34": + return resnet34() + elif name == "resnet50": + return resnet50() + else: + raise NotImplementedError("Unknown type {}".format(convnet_type)) + +class SimpleLinear(nn.Module): + ''' + Reference: + https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py + ''' + def __init__(self, in_features, out_features, bias=True): + super(SimpleLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') + nn.init.constant_(self.bias, 0) + + def forward(self, input): + return {'logits': F.linear(input, self.weight, self.bias)} + +class DER(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.convnets = nn.ModuleList() + self.pretrained = None + self.out_dim = None + self.fc = None + self.aux_fc = None + self.task_sizes = [] + + self.kwargs = kwargs + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.known_cls_num = 0 + self.total_cls_num = 0 + + self.convnet_type = 'resnet18' + + @property + def feature_dim(self): + if self.out_dim is None: + return 0 + return self.out_dim * len(self.convnets) + + def forward(self, x): + features = [convnet(x)["features"] for convnet in self.convnets] + features = torch.cat(features, 1) + + out = self.fc(features) # {logics: self.fc(features)} + + aux_logits = self.aux_fc(features[:, -self.out_dim :])["logits"] + + out.update({"aux_logits": aux_logits, "features": features}) + return out + """ + { + 'features': features + 'logits': logits + 'aux_logits':aux_logits + } + """ + + def observe(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + features = [convnet(x)["features"] for convnet in self.convnets] + features = torch.cat(features, 1) + + logit = self.fc(features)['logits'] + + if self.task_idx == 0: + loss = self.loss_fn(logit, y) + else: + loss_clf = self.loss_fn(logit, y) + aux_targets = y.clone() + aux_targets = torch.where( + aux_targets - self.known_cls_num + 1 > 0, + aux_targets - self.known_cls_num + 1, + 0, + ) + aux_logits = self.aux_fc(features[:, -self.out_dim :])["logits"] + loss_aux = F.cross_entropy(aux_logits, aux_targets) + loss = loss_aux + loss_clf + + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0), loss + + def inference(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + features = [convnet(x)["features"] for convnet in self.convnets] + features = torch.cat(features, 1) + logit = self.fc(features)['logits'] + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + def update_fc(self, nb_classes): + if len(self.convnets) == 0: + self.convnets.append(get_convnet(self.convnet_type)) + else: + self.convnets.append(get_convnet(self.convnet_type)) + self.convnets[-1].load_state_dict(self.convnets[-2].state_dict()) + + if self.out_dim is None: + self.out_dim = self.convnets[-1].out_dim + fc = self.generate_fc(self.feature_dim, nb_classes) + if self.fc is not None: + nb_output = self.fc.out_features + weight = copy.deepcopy(self.fc.weight.data) + bias = copy.deepcopy(self.fc.bias.data) + fc.weight.data[:nb_output, : self.feature_dim - self.out_dim] = weight + fc.bias.data[:nb_output] = bias + + del self.fc + self.fc = fc + + new_task_size = nb_classes - sum(self.task_sizes) + self.task_sizes.append(new_task_size) + + self.aux_fc = self.generate_fc(self.out_dim, new_task_size + 1) + + def generate_fc(self, in_dim, out_dim): + fc = SimpleLinear(in_dim, out_dim) + + return fc + + def freeze_convnets(self): + for param in self.convnets.parameters(): + param.requires_grad = False + self.convnets.eval() + + def weight_align(self, increment): + weights = self.fc.weight.data + newnorm = torch.norm(weights[-increment:, :], p=2, dim=1) + oldnorm = torch.norm(weights[:-increment, :], p=2, dim=1) + meannew = torch.mean(newnorm) + meanold = torch.mean(oldnorm) + gamma = meanold / meannew + print("alignweights,gamma=", gamma) + self.fc.weight.data[-increment:, :] *= gamma + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + self.known_cls_num = self.total_cls_num + self.total_cls_num = self.init_cls_num + self.task_idx*self.inc_cls_num + + self.freeze_convnets() + self.update_fc(self.total_cls_num) + self.loss_fn = nn.CrossEntropyLoss() + self.convnets = self.convnets.to(self.device) + self.fc = self.fc.to(self.device) + self.aux_fc = self.aux_fc.to(self.device) + + def _train(self): + self.fc.train() + # iffreeze('fc',self.fc) + self.aux_fc.train() + # iffreeze('auxfc',self.aux_fc) + for i in range(self.task_idx -1): + self.convnets[i].eval() + self.convnets[-1].train() + # for i,cov in enumerate(self.convnets): + # iffreeze(f'cov{i}',cov) + + def get_parameters(self, config): + train_parameters = [] + + train_parameters.append({"params": self.convnets.parameters()}) + + if self.fc is not None: + train_parameters.append({"params": self.fc.parameters()}) + if self.aux_fc is not None: + train_parameters.append({"params": self.aux_fc.parameters()}) + return train_parameters + +def iffreeze(name,net): + for k,v in net.named_parameters(): + print('{}{}: {}'.format(name,k, v.requires_grad)) diff --git a/core/model/dmnsp.py b/core/model/dmnsp.py new file mode 100644 index 0000000000000000000000000000000000000000..45b42be77da5707bfd35ff5b20d8f9141450c1fa --- /dev/null +++ b/core/model/dmnsp.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +""" +TODO: citation + +Adapted from TODO: source +""" + +import math +import torch +import torch.nn as nn +import numpy as np + +from torch import optim +from torch.nn import functional as F +from torch.nn.parameter import Parameter +from tqdm import tqdm + +from .backbone.transformer import ResidualAttentionBlock +from .backbone.clip import tokenize, CLIP +from .backbone.vit import ViTZoo + +VIT = ViTZoo +CLIP = CLIP + +class DMNSP(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.device = device + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.label_smoothing = kwargs['label_smoothing'] + + self._cur_task_id = -1 + self._known_classes = 0 + self.visual_U = [] + self.lamda = [[0 for _ in range(12)] for _ in range(12)] + self.lamda_scale = kwargs['lamda_scale'] + + self.accm_class_names = [] + self.curr_class_names = [] + self.accm_text_tokens = None + self.curr_text_tokens = None + + self.prompt_template = kwargs['prompt_template'] + + self._network = backbone + + for name, param in self._network.named_parameters(): + if 'adapt' not in name: + param.requires_grad = False + + if isinstance(self._network, VIT): + self.visual_transformer_blocks = [module for module in self._network.modules() if isinstance(module, ResidualAttentionBlock)] + + self.classifier_pool = nn.ModuleList([ + nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] + + [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)] + ) + + elif isinstance(self._network, CLIP): + self.visual_transformer_blocks = [module for name, module in self._network.named_modules() if isinstance(module, ResidualAttentionBlock) and 'visual' in name] + else: + assert 0 + + def observe(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + if isinstance(self._network, CLIP): + features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.curr_text_tokens) + elif isinstance(self._network, ViTZoo): + features = self._network(x) + logits_per_img = [] + for prompts in [self.classifier_pool[self._cur_task_id]]: + logits_per_img.append(prompts(features)) + logits_per_img = torch.cat(logits_per_img, dim=1) + + loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing) + + preds = logits_per_img.softmax(dim=-1).argmax(dim=1) + acc = preds.eq(y).sum().item() / y.size(0) + + loss.backward() + + if self._cur_task_id > 0: + + if isinstance(self._network, VIT): + + for name, param in self._network.named_parameters(): + for i in range(12): + if 'adapt' in name and 'down' in name and 'weight' in name: + + v = self.visual_U[i].to(self.device) + v_ = torch.mm(param.grad.data, v) + param.grad.data = torch.mm(v_, v.T) * self.lamda[int(name.split(".")[3])][i] + + elif 'adapt' in name and 'up' in name and 'weight' in name: + + v = self.visual_U[i].to(self.device) + v_ = torch.mm(v.T, param.grad.data) + param.grad.data = torch.mm(v, v_) * self.lamda[int(name.split(".")[3])][i] + + elif isinstance(self._network, CLIP): + + for name, param in self._network.named_parameters(): + for i in range(12): + if 'visual' in name and 'adapt' in name and 'down' in name and 'weight' in name: + + v = self.visual_U[i].to(self.device) + v_ = torch.mm(param.grad.data, v) + param.grad.data = torch.mm(v_, v.T) * self.lamda[int(name.split(".")[3])][i] + + elif 'visual' in name and 'adapt' in name and 'up' in name and 'weight' in name: + + v = self.visual_U[i].to(self.device) + v_ = torch.mm(v.T, param.grad.data) + param.grad.data = torch.mm(v, v_) * self.lamda[int(name.split(".")[3])][i] + + + return preds, acc, loss + + def inference(self, data, task_id = -1): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + if isinstance(self._network, CLIP): + if task_id > -1: + assert self.init_cls_num == self.inc_cls_num + features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.accm_text_tokens[task_id * self.inc_cls_num : (task_id + 1) * self.inc_cls_num]) + else: + features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.accm_text_tokens) + elif isinstance(self._network, VIT): + if task_id > -1: + assert 0, 'Not Implemented' + else: + features = self._network(x) + logits_per_img = [] + for prompts in self.classifier_pool[:self._cur_task_id + 1]: + logits_per_img.append(prompts(features)) + logits_per_img = torch.cat(logits_per_img, dim=1) + + preds = logits_per_img.softmax(dim=-1).argmax(dim=1) + + if task_id > -1: + assert self.init_cls_num == self.inc_cls_num + preds += task_id * self.inc_cls_num + + acc = preds.eq(y).sum().item() / y.size(0) + + return preds, acc + + @torch.no_grad() + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self._cur_task_id = task_idx + if task_idx == 1: + self._known_classes = self.init_cls_num + elif task_idx > 1: + self._known_classes += self.inc_cls_num + + self.curr_class_names = train_loader.dataset.get_class_names() + self.accm_class_names += self.curr_class_names + + self.curr_text_tokens = tokenize( + [self.prompt_template.format(c) for c in self.curr_class_names] + ).to(self.device) + + self.accm_text_tokens = tokenize( + [self.prompt_template.format(c) for c in self.accm_class_names] + ).to(self.device) + + if task_idx > 0: + for data in train_loader: + x = data['image'].to(self.device) + self._network(x, self.curr_text_tokens, compute_lora_feat=True) # will replace last lora_feat + + for j in range(12): # Number of layers of both vision transformer and text transformer, hardcoded + activation_visual = self.visual_transformer_blocks[j].lora_feature + activation_visual = torch.bmm(activation_visual.permute(1, 2, 0), + activation_visual.permute(1, 0, 2)).sum(dim=0) + U_visual, _, _ = torch.linalg.svd(activation_visual, full_matrices=False) + U_visual = U_visual[:, 0:1] + + for k in range(12): + v_visual = self.visual_U[k] + normalized_vector_visual = U_visual / torch.norm(U_visual) + similarities_visual = [] + + for column_visual in v_visual.t(): + normalized_column_visual = column_visual / torch.norm(column_visual) + cos_sim_visual = torch.dot(normalized_vector_visual.squeeze(), + normalized_column_visual.squeeze()) + similarities_visual.append(cos_sim_visual) + + dot_products_visual = torch.mean(torch.topk(torch.stack(similarities_visual), int(len(similarities_visual) * 00.1))[0]) + self.lamda[j][k] = torch.exp(-dot_products_visual) * self.lamda_scale + + break # first batch only + + @torch.no_grad() + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + for data in train_loader: + x = data['image'].to(self.device) + self._network(x, self.curr_text_tokens, compute_lora_feat=True) # will replace last lora_feat + + for i in range(12): + + activation = self.visual_transformer_blocks[i].lora_feature + + activation = torch.bmm(activation.permute(1, 2, 0), + activation.permute(1, 0, 2)).sum(dim=0) + + U, _, _ = torch.linalg.svd(activation, full_matrices=False) + + if task_idx == 0: + r = 0 + self.visual_U.append(U[:,max(r,1):]) + else: + r = 1 + Ui = torch.cat((self.visual_U[i], U[:, r:]), dim=1) + self.visual_U[i] = Ui + + break # first batch only + + def get_parameters(self, config): + return self._network.parameters() \ No newline at end of file diff --git a/core/model/dualprompt.py b/core/model/dualprompt.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e938c002e27ec65e78a6626d9a37b0774a3ed3 --- /dev/null +++ b/core/model/dualprompt.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{DBLP:conf/eccv/0002ZESZLRSPDP22, + author = {Zifeng Wang and + Zizhao Zhang and + Sayna Ebrahimi and + Ruoxi Sun and + Han Zhang and + Chen{-}Yu Lee and + Xiaoqi Ren and + Guolong Su and + Vincent Perot and + Jennifer G. Dy and + Tomas Pfister}, + editor = {Shai Avidan and + Gabriel J. Brostow and + Moustapha Ciss{\'{e}} and + Giovanni Maria Farinella and + Tal Hassner}, + title = {DualPrompt: Complementary Prompting for Rehearsal-Free Continual Learning}, + booktitle = {Computer Vision - {ECCV} 2022 - 17th European Conference, Tel Aviv, + Israel, October 23-27, 2022, Proceedings, Part {XXVI}}, + volume = {13686}, + pages = {631--648}, + publisher = {Springer}, + year = {2022} +} + +https://arxiv.org/pdf/2204.04799 + +Adapted from https://github.com/GT-RIPL/CODA-Prompt +""" + +import math +import copy +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +from .finetune import Finetune +from core.model.backbone.resnet import * +import numpy as np +from torch.utils.data import DataLoader + + +class Model(nn.Module): + # A model consists with a backbone and a classifier + def __init__(self, backbone, feat_dim, num_class): + super().__init__() + self.backbone = backbone + self.feat_dim = feat_dim + self.num_class = num_class + self.classifier = nn.Linear(feat_dim, num_class) + + def forward(self, x, train=True): + if train: + feat, loss = self.backbone(x, train=True) + return self.classifier(feat), loss + else: + feat = self.backbone(x, train=False) + return self.classifier(feat) + + +class DualPrompt(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.kwargs = kwargs + self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num']) + self.network.backbone.create_prompt('dual', n_tasks = kwargs['task_num'], prompt_param=[10, kwargs['e_prompt_length'], kwargs['g_prompt_length']]) + self.task_idx = 0 + self.kwargs = kwargs + + self.last_out_dim = 0 + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + self.network.backbone.task_id = task_idx + + in_features = self.network.classifier.in_features + out_features = self.network.classifier.out_features + new_out_features = self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num'] + new_fc = nn.Linear(in_features, new_out_features) + new_fc.weight.data[:out_features] = self.network.classifier.weight.data + new_fc.bias.data[:out_features] = self.network.classifier.bias.data + self.network.classifier = new_fc + self.network.to(self.device) + + self.loss_fn = nn.CrossEntropyLoss(reduction='none') + + self.out_dim = new_out_features + self.dw_k = torch.tensor(np.ones(self.out_dim + 1, dtype=np.float32)).to(self.device) + + def observe(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logit, loss = self.network(x, train=True) + + logit[:,:self.last_out_dim] = -float('inf') + dw_cls = self.dw_k[-1 * torch.ones(y.size()).long()] + + loss += (self.loss_fn(logit, y) * dw_cls).mean() + + pred = torch.argmax(logit, dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0), loss + + + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + self.last_out_dim = self.out_dim + + def inference(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + logit = self.network(x, train=False) + + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + + def get_parameters(self, config): + return list(self.network.backbone.prompt.parameters()) + list(self.network.classifier.parameters()) diff --git a/core/model/erace.py b/core/model/erace.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d303250f119839603524dce733e129c0b5b19c --- /dev/null +++ b/core/model/erace.py @@ -0,0 +1,130 @@ +""" +@misc{caccia2022new, + title={New Insights on Reducing Abrupt Representation Change in Online Continual Learning}, + author={Lucas Caccia and Rahaf Aljundi and Nader Asadi and Tinne Tuytelaars and Joelle Pineau and Eugene Belilovsky}, + year={2022}, + eprint={2104.05025}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} + +Adapted from https://github.com/pclucas14/AML +""" + +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +class distLinear(nn.Module): + def __init__(self, indim, outdim, weight=None): + super().__init__() + self.L = nn.Linear(indim, outdim, bias = False) + if weight is not None: + self.L.weight.data = Variable(weight) + + self.scale_factor = 10 + + def forward(self, x): + x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x) + x_normalized = x.div(x_norm + 0.00001) + + L_norm = torch.norm(self.L.weight, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data) + cos_dist = torch.mm(x_normalized,self.L.weight.div(L_norm + 0.00001).transpose(0,1)) + + scores = self.scale_factor * (cos_dist) + + return scores + +class Model(nn.Module): + def __init__(self, backbone, num_classes): + super().__init__() + self.backbone = backbone + self.classifier = distLinear(backbone.out_dim, num_classes) + + def forward(self, data): + return self.classifier(self.backbone(data)) + +class ERACE(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.model = Model(backbone, kwargs['num_classes']) + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.use_augs = kwargs['use_augs'] + self.device = device + self.seen_so_far = 0 + + self.task_free = kwargs['task_free'] + assert self.task_free, 'ER-ACE must be task free' + + self.sample_kwargs = { + 'amt': 10, + 'exclude_task': None + } + + self.model.to(self.device) + + def observe(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + self.inc_data = {'x': x, 'y': y, 't': self.cur_task_idx} + + logits = self.model(x) + + mask = torch.zeros_like(logits) + mask[:, self.seen_so_far:] = 1 + + if self.cur_task_idx > 0 or self.task_free: + logits = logits.masked_fill(mask == 0, -1e9) + + loss = F.cross_entropy(logits, y) + pred = logits.max(1)[1] + correct_count = (pred == y).sum().item() + total_count = y.shape[0] + + if len(self.buffer) > 0 and (self.task_free or self.cur_task_idx > 0): + re_data = self.buffer.sample_random(**self.sample_kwargs) + + re_logits = self.model(re_data['x']) + loss += F.cross_entropy(re_logits, re_data['y']) + re_pred = re_logits.max(1)[1] + correct_count += (re_pred == re_data['y']).sum().item() + total_count += re_data['y'].shape[0] + + acc = correct_count / total_count + + # only return output of incoming data, not including output of rehearsal data + return pred, acc, loss + + def inference(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + logits = self.model(x) + pred = logits.max(1)[1] + correct_count = pred.eq(y).sum().item() + acc = correct_count / y.size(0) + + return pred, acc + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + if not self.use_augs: + train_loader.dataset.trfms = test_loaders[0].dataset.trfms + + self.buffer = buffer + self.buffer.device = self.device + + self.cur_task_idx = task_idx + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + self.seen_so_far = self.init_cls_num + self.inc_cls_num * task_idx + + def add_reservoir(self): + self.buffer.add_reservoir(self.inc_data) + + def get_parameters(self, config): + return self.model.parameters() diff --git a/core/model/eraml.py b/core/model/eraml.py new file mode 100644 index 0000000000000000000000000000000000000000..7195e90d7fe371eeb4dec96cd6250c04a8c563c1 --- /dev/null +++ b/core/model/eraml.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +""" +@misc{caccia2022new, + title={New Insights on Reducing Abrupt Representation Change in Online Continual Learning}, + author={Lucas Caccia and Rahaf Aljundi and Nader Asadi and Tinne Tuytelaars and Joelle Pineau and Eugene Belilovsky}, + year={2022}, + eprint={2104.05025}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} + +Adapted from https://github.com/pclucas14/AML +""" + +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +class distLinear(nn.Module): + def __init__(self, indim, outdim, weight=None): + super().__init__() + self.L = nn.Linear(indim, outdim, bias = False) + if weight is not None: + self.L.weight.data = Variable(weight) + + self.scale_factor = 10 + + def forward(self, x): + x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x) + x_normalized = x.div(x_norm + 0.00001) + + L_norm = torch.norm(self.L.weight, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data) + cos_dist = torch.mm(x_normalized,self.L.weight.div(L_norm + 0.00001).transpose(0,1)) + + scores = self.scale_factor * (cos_dist) + + return scores + +class Model(nn.Module): + def __init__(self, backbone, num_classes): + super().__init__() + self.backbone = backbone + self.classifier = distLinear(backbone.out_dim, num_classes) + + def return_hidden(self, data): + return self.backbone(data) + + def forward(self, data): + return self.classifier(self.backbone(data)) + +class ERAML(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.model = Model(backbone, kwargs['num_classes']) + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.use_augs = kwargs['use_augs'] + self.supcon_temperature = kwargs['supcon_temperature'] + self.use_minimal_selection = kwargs['use_minimal_selection'] + self.task_free = kwargs['task_free'] + self.device = device + + self.sample_kwargs = { + 'amt': 10, + 'exclude_task': None + } + + self.model.to(self.device) + + def normalize(self, x): + x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x) + x_normalized = x.div(x_norm + 0.00001) + return x_normalized + + def sup_con_loss(self, anchor_feature, features, anch_labels=None, labels=None, + mask=None, temperature=0.1, base_temperature=0.07): + + batch_size, anchor_count, _ = features.shape + + labels = labels.contiguous().view(-1, 1) + anch_labels = anch_labels.contiguous().view(-1, 1) + mask = torch.eq(anch_labels, labels.T).float().to(self.device) + + contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) # hid_all + + # compute logits + anchor_dot_contrast = torch.div(anchor_feature @ contrast_feature.T, temperature) + + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, anchor_count) + + # compute log_prob + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) + + # loss + loss = - (temperature / base_temperature) * mean_log_prob_pos + loss = loss.view(anchor_count, batch_size).mean() + + return loss + + def process_inc(self, inc_data): + """ get loss from incoming data """ + + x, y = inc_data['x'], inc_data['y'] + + logits = self.model(x) + pred = logits.max(1)[1] + + # If task_based, see if task id >= 1 + # If task_free, see if buffer has something + if inc_data['t'] > 0 or (self.task_free and len(self.buffer) > 0): + pos_x, neg_x, pos_y, neg_y, invalid_idx, _ = self.sample( + inc_data, + task_free = self.task_free, + same_task_neg = True # If true, neg sample can only choose from inc_data, instead of inc_data + buffer + ) + + hidden = self.model.return_hidden(inc_data['x']) + hidden_norm = self.normalize(hidden[~invalid_idx]) + + all_xs = torch.cat((pos_x, neg_x)) + all_hid = self.normalize(self.model.return_hidden(all_xs)) + all_hid = all_hid.reshape(2, pos_x.size(0), -1) + pos_hid, neg_hid = all_hid[:, ~invalid_idx] + + loss = 0. + if (~invalid_idx).any(): + inc_y = y[~invalid_idx] + pos_y = pos_y[~invalid_idx] + neg_y = neg_y[~invalid_idx] + hid_all = torch.cat((pos_hid, neg_hid), dim=0) + y_all = torch.cat((pos_y, neg_y), dim=0) + + loss = self.sup_con_loss( + labels=y_all, + features=hid_all.unsqueeze(1), + anch_labels=inc_y.repeat(2), + anchor_feature=hidden_norm.repeat(2, 1), + temperature=self.supcon_temperature + ) + + else: + # do regular training at the start + loss = F.cross_entropy(logits, y) + + correct_count = (pred == y).sum().item() + + return pred, correct_count, loss + + def observe(self, data): + + inc_correct_counts, inc_total_counts, re_correct_counts, re_total_counts = 0, 0, 0, 0 + + x, y = data['image'].to(self.device), data['label'].to(self.device) + self.inc_data = {'x': x, 'y': y, 't': self.cur_task_idx} + + pred, correct_count, loss = self.process_inc(self.inc_data) + total_count = y.shape[0] + + if len(self.buffer) > 0 and (self.task_free or self.cur_task_idx > 0): + re_data = self.buffer.sample(**self.sample_kwargs) + + re_logits = self.model(re_data['x']) + loss += F.cross_entropy(re_logits, re_data['y']) + re_pred = re_logits.max(1)[1] + correct_count += (re_pred == re_data['y']).sum().item() + total_count += re_data['y'].shape[0] + + acc = correct_count / total_count + + return pred, acc, loss + + def inference(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + logits = self.model(x) + pred = logits.max(1)[1] + correct_count = pred.eq(y).sum().item() + acc = correct_count / y.size(0) + + return pred, acc + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + if not self.use_augs: + train_loader.dataset.trfms = test_loaders[0].dataset.trfms + + self.buffer = buffer + self.buffer.device = self.device + if self.use_minimal_selection: + self.sample = self.buffer.sample_minimal_pos_neg + else: + self.sample = self.buffer.sample_pos_neg + + self.cur_task_idx = task_idx + + def add_reservoir(self): + self.buffer.add(self.inc_data) + + def get_parameters(self, config): + return self.model.parameters() diff --git a/core/model/ewc.py b/core/model/ewc.py new file mode 100644 index 0000000000000000000000000000000000000000..018565cc8f2d9858f4563d8902c64b43e34cfcb1 --- /dev/null +++ b/core/model/ewc.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +""" +@article{DBLP:journals/corr/KirkpatrickPRVD16, + author = {James Kirkpatrick and + Razvan Pascanu and + Neil C. Rabinowitz and + Joel Veness and + Guillaume Desjardins and + Andrei A. Rusu and + Kieran Milan and + John Quan and + Tiago Ramalho and + Agnieszka Grabska{-}Barwinska and + Demis Hassabis and + Claudia Clopath and + Dharshan Kumaran and + Raia Hadsell}, + title = {Overcoming catastrophic forgetting in neural networks}, + journal = {CoRR}, + volume = {abs/1612.00796}, + year = {2016} +} + +https://arxiv.org/abs/1612.00796 + +Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py +""" + + +import math +import copy +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +from .finetune import Finetune +from core.model.backbone.resnet import * +import numpy as np +from torch.utils.data import DataLoader +from torch import optim + + +class Model(nn.Module): + # A model consists with a backbone and a classifier + def __init__(self, backbone, feat_dim, num_class): + super().__init__() + self.backbone = backbone + self.feat_dim = feat_dim + self.num_class = num_class + self.classifier = nn.Linear(feat_dim, num_class) + + def forward(self, x): + return self.get_logits(x) + + def get_logits(self, x): + logits = self.classifier(self.backbone(x)['features']) + return logits + +class EWC(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.kwargs = kwargs + self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num']) + + self.ref_param = {n: p.clone().detach() for n, p in self.network.named_parameters() + if p.requires_grad} + self.fisher = {n: torch.zeros(p.shape).to(self.device) for n, p in self.network.named_parameters() + if p.requires_grad} + self.lamda = self.kwargs['lamda'] + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + in_features = self.network.classifier.in_features + out_features = self.network.classifier.out_features + + new_fc = nn.Linear(in_features, self.kwargs['init_cls_num'] + task_idx * self.kwargs['inc_cls_num']) + new_fc.weight.data[:out_features] = self.network.classifier.weight.data + new_fc.bias.data[:out_features] = self.network.classifier.bias.data + self.network.classifier = new_fc + self.network.to(self.device) + + def observe(self, data): + x, y = data['image'].to(self.device), data['label'].to(self.device) + logit = self.network(x) + + if self.task_idx == 0: + loss = F.cross_entropy(logit, y) + else: + + + + old_classes = self.network.classifier.out_features - self.kwargs['inc_cls_num'] + + #print(old_classes) + #print(logit[:, old_classes:].shape) + #print(y) + #print(y-old_classes) + + loss = F.cross_entropy(logit[:, old_classes:], y - old_classes) + loss += self.lamda * self.compute_ewc() + + pred = torch.argmax(logit, dim=1) + + #print(pred) + #print(y) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0), loss + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + """ + Args: + task_idx (int): The index of the current task. + buffer: Buffer object used in previous tasks. + train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset. + test_loaders (list of DataLoader): List of dataloaders for the test datasets. + + Code Reference: + https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py + https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py + """ + + # record the parameters + self.ref_param = {n: p.clone().detach() for n, p in self.network.named_parameters() + if p.requires_grad} + # the shape of new fisher is changed + new_fisher = self.getFisher(train_loader) + # using growing alpha + alpha = 1 - self.kwargs['inc_cls_num']/self.network.classifier.out_features + for n, p in self.fisher.items(): + new_fisher[n][:len(self.fisher[n])] = alpha * p + (1 - alpha) * new_fisher[n][:len(self.fisher[n])] + + self.fisher = new_fisher + + def inference(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + logit = self.network(x) + + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + def getFisher(self, train_loader): + """ + Compute the Fisher Information Matrix for the parameters of the network. + + Args: + train_loader (torch.utils.data.DataLoader): Dataloader for the training dataset. + + Returns: + dict: Dictionary of Fisher Information Matrices for each parameter. + + Code Reference: + https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py + https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py + """ + def accumulate(fisher): + """ + Accumulate the squared gradients for the Fisher Information Matrix. + + Args: + fisher (dict): Dictionary containing the current Fisher Information matrices. + + Returns: + dict: Updated Fisher Information matrices. + """ + for n, p in self.network.named_parameters(): + if p.grad is not None and n in fisher.keys(): + fisher[n] += p.grad.pow(2).clone() * len(y) + return fisher + + # Initialize Fisher Information matrices with zeros + fisher = { + n: torch.zeros_like(p).to(self.device) for n, p in self.network.named_parameters() + if p.requires_grad + } + + self.network.train() + optimizer = optim.SGD(self.network.parameters(), lr=0.1) + + loss_fn = torch.nn.CrossEntropyLoss() + # Iterate over the training data + for data in train_loader: + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + logits = self.network(x) + loss = loss_fn(logits, y) + + optimizer.zero_grad() + loss.backward() + + # Accumulate Fisher Information + fisher = accumulate(fisher) + + # Normalize Fisher Information matrices by the number of samples + num_samples = train_loader.batch_size * len(train_loader) + for n, p in fisher.items(): + fisher[n] = p / num_samples + return fisher + + def compute_ewc(self): + """ + Compute the Elastic Weight Consolidation (EWC) loss. + + This function calculates the EWC loss based on the stored Fisher Information matrices + and reference parameters from a previous task. + + References: + - https://github.com/G-U-N/PyCIL/blob/master/models/ewc.py + - https://github.com/mmasana/FACIL/blob/master/src/approach/ewc.py + + Returns: + torch.Tensor: The computed EWC loss. + """ + loss = 0 + for n, p in self.network.named_parameters(): + if n in self.fisher.keys(): + loss += torch.sum(self.fisher[n] * (p[:len(self.ref_param[n])] - self.ref_param[n]).pow(2)) / 2 + return loss + + def get_parameters(self, config): + train_parameters = [] + train_parameters.append({"params": self.network.parameters()}) + return train_parameters \ No newline at end of file diff --git a/core/model/finetune.py b/core/model/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..9c45f0966181835b7d5aaee74325c6fdb18b47f9 --- /dev/null +++ b/core/model/finetune.py @@ -0,0 +1,52 @@ +import torch +from torch import nn + +class Finetune(nn.Module): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__() + self.backbone = backbone + self.feat_dim = feat_dim + self.num_class = num_class + self.classifier = nn.Linear(feat_dim, num_class) + self.loss_fn = nn.CrossEntropyLoss(reduction='mean') + self.device = kwargs['device'] + self.kwargs = kwargs + + def observe(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logit = self.classifier(self.backbone(x)['features']) + loss = self.loss_fn(logit, y) + + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0), loss + + def inference(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + logit = self.classifier(self.backbone(x)['features']) + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + def forward(self, x): + return self.classifier(self.backbone(x)['features']) + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + pass + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + pass + + def get_parameters(self, config): + train_parameters = [] + train_parameters.append({"params": self.backbone.parameters()}) + train_parameters.append({"params": self.classifier.parameters()}) + return train_parameters + diff --git a/core/model/gpm.py b/core/model/gpm.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3422ee0bb88284e09286eb8a3b4217524c3d3a --- /dev/null +++ b/core/model/gpm.py @@ -0,0 +1,207 @@ +""" +@inproceedings{ + saha2021gradient, + title={Gradient Projection Memory for Continual Learning}, + author={Gobinda Saha and Isha Garg and Kaushik Roy}, + booktitle={International Conference on Learning Representations}, + year={2021}, + url={https://openreview.net/forum?id=3AOj0RCNC2} +} + +Code Reference: +https://github.com/sahagobinda/GPM +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from .backbone.alexnet import Conv2d_TRGP, Linear_TRGP + +class Network(nn.Module): + + def __init__(self, backbone, **kwargs): + + super().__init__() + self.backbone = backbone + + self.classifiers = nn.ModuleList([ + nn.Linear(backbone.feat_dim, kwargs['init_cls_num'], bias = False)] + + [nn.Linear(backbone.feat_dim, kwargs['inc_cls_num'], bias = False) for _ in range(kwargs['task_num'] - 1)] + ) + + def forward(self, data, compute_input_matrix = False): + + logits = [] + image_features = self.backbone(data, compute_input_matrix) + for classifier in self.classifiers: + logits.append(classifier(image_features)) + + return logits + +class GPM(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + self.network = Network(backbone, **kwargs) + self.device = device + + self.task_num = kwargs["task_num"] + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self._known_classes = 0 + + self.feature_list = [] + self.feature_mat = [] + + self.layers = [] # 3 Conv2d, Then 2 Linear + for module in self.network.modules(): + if isinstance(module, Conv2d_TRGP) or isinstance(module, Linear_TRGP): + self.layers.append(module) + + self.network.to(self.device) + + def observe(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + logits = self.network(x) + loss = F.cross_entropy(logits[self.cur_task], y) + + preds = logits[self.cur_task].max(1)[1] + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + loss.backward() + + if self.cur_task > 0: + for i, module in enumerate(self.layers): + sz = module.weight.grad.data.shape[0] + module.weight.grad.data = module.weight.grad.data - (module.weight.grad.data.view(sz,-1) @ self.feature_mat[i]).view(module.weight.shape) + + return preds, acc, loss + + def inference(self, data, task_id = -1): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + # Task-Aware (Task-Incremetanl Scenario) + if task_id > -1: + + if task_id == 0: + bias_classes = 0 + elif task_id == 1: + bias_classes = self.init_cls_num + else: + bias_classes = self.init_cls_num + (task_id - 1) * self.inc_cls_num + + logits = self.network(x) + preds = logits[task_id].max(1)[1] + bias_classes + + # Task-Agnostic (Class-Incremetanl Scenario) + else: + + logits = torch.cat(self.network(x), dim=-1) + preds = logits.max(1)[1] + + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self.cur_task = task_idx + + if task_idx == 1: + self._known_classes += self.init_cls_num + elif task_idx > 1: + self._known_classes += self.inc_cls_num + + if task_idx > 0: + + self.feature_mat = [torch.tensor(feat @ feat.T, dtype=torch.float32, device=self.device) for feat in self.feature_list] + + for name, param in self.network.named_parameters(): + param.requires_grad_(True) + if 'bn' in name: + param.requires_grad_(False) + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + x = [] + for batch in train_loader: + x.append(batch['image'].to(self.device)) + + x = torch.cat(x, dim = 0) + + # hardcoded, choose 125 input from it + indices = torch.randperm(x.size(0)) + selected_indices = indices[:125] + x = x[selected_indices] + + self.network.eval() + self.network(x, compute_input_matrix = True) + + batch_list = [2*12,100,100] + ksize = [4, 3, 2] # kernel size of each conv layer + conv_output_size = [29, 12, 5] # output size of each conv layer + in_channel = [3, 64, 128] # input channel of each conv layer + + mat_list = [] # representation (activation) of each layer + + for i, module in enumerate(self.layers): + + if isinstance(module, Conv2d_TRGP): + bsz, ksz, s, inc = batch_list[i], ksize[i], conv_output_size[i], in_channel[i] + + # act is the input of each layer (both conv and linear) + mat = np.zeros((ksz * ksz * inc, s * s * bsz)) + act = module.input_matrix.detach().cpu().numpy() + + k = 0 + for kk in range(bsz): + for ii in range(s): + for jj in range(s): + mat[:,k]=act[kk, :, ii:ksz+ii, jj:ksz+jj].reshape(-1) + k += 1 + + mat_list.append(mat) + elif isinstance(module, Linear_TRGP): + mat_list.append(module.input_matrix.detach().cpu().numpy().T) + + threshold = 0.97 + task_idx * 0.003 + + # get the space for each layer + if task_idx == 0: + for i, activation in enumerate(mat_list): + + U, S, _ = np.linalg.svd(activation, full_matrices = False) + # criteria (Eq-5) + sval_total = (S**2).sum() + sval_ratio = (S**2)/sval_total + r = np.sum(np.cumsum(sval_ratio) < threshold) + + self.feature_list.append(U[:, :r]) + else: + for i, activation in enumerate(mat_list): + + _, S, _ = np.linalg.svd(activation, full_matrices = False) + sval_total = (S**2).sum() + + act_hat = activation - self.feature_list[i] @ self.feature_list[i].T @ activation + U, S, _ = np.linalg.svd(act_hat, full_matrices=False) + sval_hat = (S**2).sum() + sval_ratio = (S**2)/sval_total + accumulated_sval = (sval_total-sval_hat)/sval_total + + if accumulated_sval >= threshold: + print (f'Skip Updating GPM for layer: {i+1}') + else: + r = np.sum(np.cumsum(sval_ratio) + accumulated_sval < threshold) + 1 + Ui = np.hstack((self.feature_list[i], U[:, :r])) + self.feature_list[i] = Ui[:, :min(Ui.shape[0], Ui.shape[1])] + + def get_parameters(self, config): + return self.network.parameters() \ No newline at end of file diff --git a/core/model/icarl.py b/core/model/icarl.py new file mode 100644 index 0000000000000000000000000000000000000000..a7dfd62256869b63c4c875e7c49076b447933925 --- /dev/null +++ b/core/model/icarl.py @@ -0,0 +1,287 @@ +""" +@inproceedings{rebuffi2017icarl, + title={icarl: Incremental classifier and representation learning}, + author={Rebuffi, Sylvestre-Alvise and Kolesnikov, Alexander and Sperl, Georg and Lampert, Christoph H}, + booktitle={Proceedings of the IEEE conference on Computer Vision and Pattern Recognition}, + pages={2001--2010}, + year={2017} +} +https://arxiv.org/abs/1611.07725 +""" + +from typing import Iterator +import torch +from torch import nn +from torch.nn import functional as F +from copy import deepcopy +import numpy as np +from torch.nn.parameter import Parameter +from torch.utils.data import DataLoader, Dataset +import PIL +import os +import copy + +class Model(nn.Module): + # A model consists with a backbone and a classifier + def __init__(self, backbone, feat_dim, num_class): + super().__init__() + self.backbone = backbone + self.feat_dim = feat_dim + self.num_class = num_class + self.classifier = nn.Linear(feat_dim, num_class) + + def forward(self, x): + return self.get_logits(x) + + def get_logits(self, x): + logits = self.classifier(self.backbone(x)['features']) + return logits + + + +class ICarl(nn.Module): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__() + + # device setting + self.device = kwargs['device'] + + # current task index + self.cur_task_id = 0 + + # current task class indexes + self.cur_cls_indexes = None + + # Build model structure + self.network = Model(backbone, feat_dim, num_class) + + # Store old network + self.old_network = None + + # the previous class num before this task + self.prev_cls_num = 0 + + # the total class num containing this task + self.accu_cls_num = 0 + + + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.task_num = kwargs['task_num'] + + # class prototype vector + self.class_means = None + + + # only the current model is optimized + def get_parameters(self, config): + return self.network.parameters() + + + def observe(self, data): + # get data and labels + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + # compute logits and loss + logits, loss = self.criterion(x, y) + + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0), loss + + + def inference(self, data): + + # if self.class_means is not None: + # print(len(self.class_means), self.accu_cls_num) + + if self.class_means is not None and len(self.class_means) == self.accu_cls_num: + # we only test when class mean vector computation is finished. + return self.NCM_classify(data) + + else: + # class mean vector for this task have not computed yet, + # call this function after func "after_task" called, + # and return value of this "inference" function is computed + # via model forward logits + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + logits = self.network(x)[:, :self.accu_cls_num] + pred = torch.argmax(logits, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + + + def NCM_classify(self, data): + + def metric(x, y): + """Calculate the pair-wise euclidean distance between input tensor `x` and `y`. + Args: + x (Tensor): to be calculated for distance, with shape (N, D) + y (Tensor): to be calculated for distance, with shape (M, D), where D is embedding size. + + Returns: + pair euclidean distance tensor with shape (N, M) + and dist[i][j] represent the distance between x[i] and y[j] + """ + n = x.size(0) + m = y.size(0) + x = x.unsqueeze(1).expand(n, m, -1) + y = y.unsqueeze(0).expand(n, m, -1) + return torch.pow(x - y, 2).sum(2) # (N, M) + + # using NCM + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + feats = feats = self.network.backbone(x)['features'] + feats = feats.view(feats.size(0), -1) + distance = metric(feats, self.class_means) + + pred = torch.argmin(distance, dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0) + + + def forward(self, x): + return self.network(x)[:, self.accu_cls_num] + + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + if self.cur_task_id == 0: + self.accu_cls_num = self.init_cls_num + else: + self.accu_cls_num += self.inc_cls_num + + self.cur_cls_indexes = np.arange(self.prev_cls_num, self.accu_cls_num) + + + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + # freeze old network as KD teacher + + self.old_network = copy.deepcopy(self.network) + self.old_network.eval() + + self.prev_cls_num = self.accu_cls_num + + # update buffer + buffer.reduce_old_data(self.cur_task_id, self.accu_cls_num) + + + val_transform = test_loaders[0].dataset.trfms + buffer.update(self.network, train_loader, val_transform, + self.cur_task_id, self.accu_cls_num, self.cur_cls_indexes, + self.device) + + # compute class mean vector via samples in buffer + self.class_means = self.calc_class_mean(buffer, + train_loader, + val_transform, + self.device).to(self.device) + self.cur_task_id += 1 + + + + + + def criterion(self, x, y): + def _KD_loss(pred, soft, T=2): + """ + Compute the knowledge distillation (KD) loss between the predicted logits and the soft target. + Code Reference: + KD loss function is borrowed from: https://github.com/G-U-N/PyCIL/blob/master/models/icarl.py + """ + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] + + cur_logits = self.network(x)[:, :self.accu_cls_num] + loss_clf = F.cross_entropy(cur_logits, y) + + if self.cur_task_id > 0: + old_logits = self.old_network(x) + loss_kd = _KD_loss( + cur_logits[:, : self.prev_cls_num], + old_logits[:, : self.prev_cls_num], + ) + loss = loss_clf + loss_kd + else: + loss = loss_clf + + return cur_logits, loss + + + + + def calc_class_mean(self, buffer, train_loader, val_transform, device): + + # mini dataset simulating all samples in the buffer + class miniBufferDataset(Dataset): + def __init__(self, root, mode, image_list, label_list, transforms): + self.data_root = root + self.mode = mode + self.images = image_list + self.labels = label_list + self.transforms = transforms + + def __getitem__(self, idx): + img_path = self.images[idx] + label = self.labels[idx] + image = PIL.Image.open(os.path.join(self.data_root, self.mode, img_path)).convert("RGB") + image = self.transforms(image) + return {"image": image, "label": label} + + def __len__(self): + return len(self.labels) + + root_path = train_loader.dataset.data_root + mode = train_loader.dataset.mode + image_list = buffer.images + label_list = buffer.labels + ds = miniBufferDataset(root_path, mode, image_list, label_list, val_transform) + + icarl_loader = DataLoader(ds, + batch_size=train_loader.batch_size, + shuffle=False, + num_workers=train_loader.num_workers, + pin_memory=train_loader.pin_memory) + + + # compute features for all training samples + extracted_features = [] + extracted_targets = [] + with torch.no_grad(): + self.network.eval() + for data in icarl_loader: + images = data['image'].to(device) + labels = data['label'].to(device) + feats = self.network.backbone(images)['features'] + # normalize + extracted_features.append(feats / feats.norm(dim=1).view(-1, 1)) + extracted_targets.extend(labels) + + extracted_features = torch.cat(extracted_features).cpu() + extracted_targets = torch.stack(extracted_targets).cpu() + + all_class_means = [] + for curr_cls in np.unique(extracted_targets): + # get all indices from current class + cls_ind = np.where(extracted_targets == curr_cls)[0] + # get all extracted features for current class + cls_feats = extracted_features[cls_ind] + # add the exemplars to the set and normalize + cls_feats_mean = cls_feats.mean(0) / cls_feats.mean(0).norm() + + all_class_means.append(cls_feats_mean) + + return torch.stack(all_class_means) diff --git a/core/model/l2p.py b/core/model/l2p.py new file mode 100644 index 0000000000000000000000000000000000000000..074077027998d2a9f7b1504a1c7c707c88f304b0 --- /dev/null +++ b/core/model/l2p.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{DBLP:conf/cvpr/0002ZL0SRSPDP22, + author = {Zifeng Wang and + Zizhao Zhang and + Chen{-}Yu Lee and + Han Zhang and + Ruoxi Sun and + Xiaoqi Ren and + Guolong Su and + Vincent Perot and + Jennifer G. Dy and + Tomas Pfister}, + title = {Learning to Prompt for Continual Learning}, + booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, + {CVPR} 2022, New Orleans, LA, USA, June 18-24, 2022}, + pages = {139--149}, + publisher = {{IEEE}}, + year = {2022} +} + +https://arxiv.org/abs/2112.08654 + +Adapted from https://github.com/GT-RIPL/CODA-Prompt +""" + +import math +import copy +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from core.model.backbone.resnet import * + +class Model(nn.Module): + def __init__(self, backbone, embed_dim, total_cls_num): + super().__init__() + self.backbone = backbone + self.classifier = nn.Linear(embed_dim, total_cls_num, bias=True) + + def forward(self, x, train=True): + feat, reduce_sim = self.backbone(x, train=train) + return self.classifier(feat), reduce_sim + +class L2P(nn.Module): + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.device = device + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.total_cls_num = kwargs['num_class'] + self.task_num = kwargs['task_num'] + self.embed_dim = kwargs['feat_dim'] + self.pull_constraint_coeff = kwargs['pull_constraint_coeff'] + self.cur_task_id = 0 + self._known_classes = 0 + + self.network = Model(backbone, self.embed_dim, self.total_cls_num) + self.network.backbone.create_prompt( + prompt_flag = 'l2p', + length = kwargs['prompt_length'], # L_p + prompt_init = nn.init.uniform_, + pool_size = kwargs['pool_size'], # M + top_k = kwargs['top_k'], # N + num_layers = 1, + embed_dim = self.embed_dim + ) + self.network.to(self.device) + + self.unfrezeed_params = [] + for name, param in self.network.named_parameters(): + param.requires_grad_(False) + if 'prompt' in name or 'classifier' in name: + param.requires_grad_(True) + self.unfrezeed_params.append(param) + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self.cur_task_id = task_idx + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + + def observe(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + logits, reduce_sim = self.network(x, train=True) + + if self.cur_task_id == 0: + mask = np.arange(self.init_cls_num) + else: + mask = np.arange(self.inc_cls_num) + self._known_classes + + not_mask = np.setdiff1d(np.arange(self.total_cls_num), mask) + not_mask = torch.tensor(not_mask, dtype=torch.int64).to(self.device) + logits = logits.index_fill(dim=1, index=not_mask, value=float('-inf')) + + loss = F.cross_entropy(logits, y) - self.pull_constraint_coeff * reduce_sim + + loss.backward() + torch.nn.utils.clip_grad_norm_(self.unfrezeed_params, 1.0) + + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() / x.size(0) + + return pred, acc, loss + + def inference(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + logits, _ = self.network(x, train=False) + + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() / x.size(0) + return pred, acc + + def get_parameters(self, config): + + return self.unfrezeed_params diff --git a/core/model/lora_sub.py b/core/model/lora_sub.py new file mode 100644 index 0000000000000000000000000000000000000000..22ddadae678205ceca00f06423f5718cf4c4a196 --- /dev/null +++ b/core/model/lora_sub.py @@ -0,0 +1,432 @@ +""" +@inproceedings{liu2025lora, + title={LoRA Subtraction for Drift-Resistant Space in Exemplar-Free Continual Learning}, + author={Liu, Xuan and Chang, Xiaobin}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2025} +} + +Adapted from https://github.com/scarlet0703/LoRA-Sub-DRS +""" + +import torch +import torch.nn as nn +import copy +import numpy as np +import math + +from copy import deepcopy +from torch.optim.optimizer import Optimizer +from torch.nn import functional as F +from tqdm import tqdm +from collections import defaultdict +from scipy.spatial.distance import cdist + +from .backbone.transformer import MultiHeadAttention_LoRA_Sub + +class AugmentedTripletLoss(nn.Module): + def __init__(self, margin=1.0, norm=2): + super(AugmentedTripletLoss, self).__init__() + self.margin = margin + self.norm = norm + self.ranking_loss = nn.MarginRankingLoss(margin=margin) + + def forward(self, inputs, targets, center): + device = (torch.device('cuda') + if inputs.is_cuda + else torch.device('cpu')) + n = inputs.size(0) # batch_size + + # Compute pairwise distance, replace by the official when merged + dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) + dist = dist + dist.t() + dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2) + + dist = dist.clamp(min=1e-12).sqrt() # for numerical stability + + # For each anchor, find the hardest positive and negative + mask = targets.expand(n, n).eq(targets.expand(n, n).t()) + num_proto = len(center) + dist_ap, dist_an = [], [] + for i in range(n): + dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) + if dist[i][mask[i] == 0].numel() == 0: + dist_an.append((dist[i][mask[i]].max()+self.margin).unsqueeze(0)) + else: + dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) + dist_ap = torch.cat(dist_ap) + if num_proto > 0: + center = torch.from_numpy(center / np.linalg.norm(center, axis=1)[:, None]).to(device) + for i in range(n): + for j in range(num_proto): + distp = torch.norm(inputs[i].unsqueeze(0) - center[j], self.norm).clamp(min=1e-12) + dist_an[i] = min(dist_an[i].squeeze(0), distp).unsqueeze(0) + dist_an = torch.cat(dist_an) + # Compute ranking hinge loss + y = torch.ones_like(dist_an) + loss = self.ranking_loss(dist_an, dist_ap, y) + return loss + +class Adam(Optimizer): + r"""Implements Adam algorithm. + + It has been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, svd=False, thres=1.001, + weight_decay=0, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, svd=svd, + thres=thres) + super(Adam, self).__init__(params, defaults) + + self.eigens = defaultdict(dict) + self.transforms = defaultdict(dict) + + def __setstate__(self, state): + super(Adam, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + group.setdefault('svd', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + svd = group['svd'] + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + 'Adam does not support sparse gradients, please consider SparseAdam instead') + + update = self.get_update(group, grad, p) + + if svd and len(self.transforms) > 0: + if len(update.shape) == 4: + # the transpose of the manuscript + update_ = torch.mm(update.view(update.size( + 0), -1), self.transforms[p]).view_as(update) + else: + if self.transforms[p].shape[0]==update.shape[0]: + update_ = torch.mm(self.transforms[p], update) + else: + update_ = torch.mm(update, self.transforms[p]) + else: + update_ = update + + p.data.add_(update_) + return loss + + def get_transforms(self): + for group in self.param_groups: + svd = group['svd'] + if svd is False: + continue + + for p in group['params']: + thres = group['thres'] + if p.requires_grad == False or thres == 1.0: + continue + eigen_values = self.eigens[p]['eigen_value'] + cumulative_sum = eigen_values.cumsum(dim=0) / eigen_values.sum() + num_vectors = (cumulative_sum >= thres).nonzero(as_tuple=True)[0][0] + 1 + print('reserving basis {}/{}; cond: {}, ratio:{}'.format( + num_vectors, eigen_values.shape[0], + eigen_values[0] / eigen_values[-1], + cumulative_sum[num_vectors - 1] + )) + basis = self.eigens[p]['eigen_vector'][:, :num_vectors] + transform = torch.mm(basis, basis.transpose(1, 0)) + self.transforms[p] = transform / torch.norm(transform) + self.transforms[p].detach_() + + def get_eigens(self, fea_in): + + for group in self.param_groups: + if group['svd']: + for p in group['params']: + if p.requires_grad: + eigen = self.eigens[p] + _, eigen_value, eigen_vector = torch.svd(fea_in[p], some=False) + eigen['eigen_value'] = eigen_value + eigen['eigen_vector'] = eigen_vector + + def get_update(self, group, grad, p): + amsgrad = group['amsgrad'] + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = group['lr'] * \ + math.sqrt(bias_correction2) / bias_correction1 + update = - step_size * exp_avg / denom + return update + +class Model(nn.Module): + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self._cur_task_id = -1 + self.backbone = backbone + self.device = device + self.classifier_pool = nn.ModuleList([ + nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] + + [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)] + ) + + def update_fc(self): + + self._cur_task_id += 1 + + def update_input_matrix(self, x): + + self.backbone(x, get_input_matrix = True) + + def extract_features(self, x): + return self.backbone(x) + + def forward(self, x): + + logits = [] + features = self.backbone(x) + + for prompts in [self.classifier_pool[self._cur_task_id]]: + logits.append(prompts(features)) + + return { + 'logits': torch.cat(logits, dim=1), + 'features': features + } + +class LoRAsub_DRS(nn.Module): + + def __init__(self, backbone, device, **kwargs): + + super().__init__() + + self.device = device + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.task_num = kwargs["task_num"] + self.fc_lrate = kwargs["fc_lrate"] + self.margin_inter = kwargs["margin_inter"] + self.lambada = kwargs["lambada"] + self._known_classes = 0 + self._total_classes = 0 + self._cur_task = 0 + + self._network = Model(backbone, device, **kwargs) + self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_LoRA_Sub)] + self.criterion = AugmentedTripletLoss(margin=self.margin_inter).to(self.device) + self._protos = [] + + def observe(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + outputs = self._network(x) + logits, features = outputs['logits'], outputs['features'] + + ATL = self.criterion( + features / features.norm(dim=-1, keepdim=True), + y, + self._protos + ) + loss = F.cross_entropy(logits, y) + self.lambada * ATL + + preds = logits.max(1)[1] + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc, loss + + def inference(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + features = self._network.extract_features(x) + features = (features.T / (np.linalg.norm(features.T, axis=0) + 1e-8)).T + + class_means = self._protos / np.linalg.norm(self._protos, axis=1)[:, None] + + dists = cdist(class_means, features, 'sqeuclidean') + scores = dists.T + + #preds = np.argsort(scores, axis=1)[:, :1] + preds = np.argmin(scores, axis=1) + + correct_count = (preds == y.cpu().numpy()).sum() + acc = correct_count / y.size(0) + + return preds, acc + + @torch.no_grad() + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self._known_classes = self._total_classes + self._total_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + + self._network.update_fc() + self._network = self._network.to(self.device) + + for module in self.attention_modules: + module.init_param() + + unfrezeed_params = [] + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + if f'classifier_pool.{self._cur_task}.' in name or \ + f'lora'in name: + param.requires_grad_(True) + unfrezeed_params.append(name) + + print(f"Current task : {task_idx}, Parameters to be updated: {len(unfrezeed_params)}") + + if task_idx > 0: + for batch in tqdm(train_loader, desc="Forwarding to get input matrix"): + self._network.update_input_matrix(x = batch['image'].to(self.device)) + + self.fea_in = {} + + for module in self.attention_modules: + self.fea_in[module.lora_A_k.weight] = deepcopy(module.cur_matrix).to(self.device) + self.fea_in[module.lora_A_v.weight] = deepcopy(module.cur_matrix).to(self.device) + self.fea_in[module.lora_B_k.weight] = deepcopy(module.cur_matrix).to(self.device) + self.fea_in[module.lora_B_v.weight] = deepcopy(module.cur_matrix).to(self.device) + module.reset_input_matrix() + + @torch.no_grad() + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + for module in self.attention_modules: + module.save_weight() + + # Build Proto + for class_idx in range(self._known_classes, self._total_classes): + + inputs_list = [] + + for batch in train_loader: + x, y = batch['image'].to(self.device), batch['label'].to(self.device) + inputs_list.append(x[y == class_idx]) + + class_inputs = torch.cat(inputs_list, dim=0) + features_list = [] + + for start_idx in range(0, class_inputs.shape[0], 128): + end_idx = min(start_idx + 128, class_inputs.shape[0]) + batch_inputs = class_inputs[start_idx:end_idx].to(self.device) + feats = self._network.extract_features(batch_inputs) + features_list.append(feats.detach().cpu().numpy()) + + features = np.concatenate(features_list, axis=0) + class_mean = np.mean(features, axis=0) + self._protos.append(class_mean) + + assert len(self._protos) > 0 + + self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + self._cur_task += 1 + + def get_parameters(self, config): + return self._network.parameters() + + def get_optimizer(self, lr, weight_decay): + + fea_params = [] + for module in self.attention_modules: + fea_params.append(module.lora_A_k.weight) + fea_params.append(module.lora_A_v.weight) + fea_params.append(module.lora_B_k.weight) + fea_params.append(module.lora_B_v.weight) + + cls_params = [ + self._network.classifier_pool[self._cur_task].weight, + self._network.classifier_pool[self._cur_task].bias, + ] + + model_optimizer_arg = {'params': [{'params': fea_params, 'svd': True, 'lr': lr, + 'thres': 0.99}, + {'params': cls_params, 'weight_decay': weight_decay, + 'lr': self.fc_lrate}], + 'weight_decay': weight_decay, + 'betas': (0.9, 0.999) + } + + optim = Adam(**model_optimizer_arg) + + if self._cur_task > 0: + optim.get_eigens(self.fea_in) + optim.get_transforms() + + return optim \ No newline at end of file diff --git a/core/model/lucir.py b/core/model/lucir.py new file mode 100644 index 0000000000000000000000000000000000000000..bbdcd4b6d42e7c60dcec15a21e43652b0d54bd9b --- /dev/null +++ b/core/model/lucir.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{DBLP:conf/cvpr/HouPLWL19, + title = {Learning a Unified Classifier Incrementally via Rebalancing}, + author = {Saihui Hou and + Xinyu Pan and + Chen Change Loy and + Zilei Wang and + Dahua Lin}, + booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR} + 2019, Long Beach, CA, USA, June 16-20, 2019}, + pages = {831--839}, + publisher = {Computer Vision Foundation / {IEEE}}, + year = {2019} +} +https://openaccess.thecvf.com/content_CVPR_2019/html/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.html + +Adapted from https://github.com/hshustc/CVPR19_Incremental_Learning +""" + +import math +import copy +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +from .finetune import Finetune +from core.model.backbone.resnet import * +import numpy as np +from torch.utils.data import DataLoader + + +cur_features = [] +ref_features = [] +old_scores = [] +new_scores = [] +def get_ref_features(self, inputs, outputs): + global ref_features + ref_features = inputs[0] + +def get_cur_features(self, inputs, outputs): + global cur_features + cur_features = inputs[0] + +def get_old_scores_before_scale(self, inputs, outputs): + global old_scores + old_scores = outputs + +def get_new_scores_before_scale(self, inputs, outputs): + global new_scores + new_scores = outputs + + + +class Model(nn.Module): + # A model consists with a backbone and a classifier + def __init__(self, backbone, feat_dim, num_class): + super().__init__() + self.backbone = backbone + self.feat_dim = feat_dim + self.num_class = num_class + self.classifier = CosineLinear(feat_dim, num_class) + + def forward(self, x): + return self.get_logits(x) + + def get_logits(self, x): + logits = self.classifier(self.backbone(x)['features']) + return logits + + + +class LUCIR(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.kwargs = kwargs + self.network = Model(self.backbone, feat_dim, kwargs['init_cls_num']) + self.K = kwargs['K'] + self.lw_mr = kwargs['lw_mr'] + self.ref_model = None + self.task_idx = 0 + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + + if task_idx == 1: + self.ref_model = copy.deepcopy(self.network) + in_features = self.network.classifier.in_features + out_features = self.network.classifier.out_features + new_fc = SplitCosineLinear(in_features, out_features, self.kwargs['inc_cls_num']) + new_fc.fc1.weight.data = self.network.classifier.weight.data + new_fc.sigma.data = self.network.classifier.sigma.data + self.network.classifier = new_fc + lamda_mult = out_features*1.0 / self.kwargs['inc_cls_num'] + + + elif task_idx > 1: + self.ref_model = copy.deepcopy(self.network) + in_features = self.network.classifier.in_features + out_features1 = self.network.classifier.fc1.out_features + out_features2 = self.network.classifier.fc2.out_features + new_fc = SplitCosineLinear(in_features, out_features1+out_features2, self.kwargs['inc_cls_num']).to(self.device) + new_fc.fc1.weight.data[:out_features1] = self.network.classifier.fc1.weight.data + new_fc.fc1.weight.data[out_features1:] = self.network.classifier.fc2.weight.data + new_fc.sigma.data = self.network.classifier.sigma.data + self.network.classifier = new_fc + lamda_mult = (out_features1+out_features2)*1.0 / (self.kwargs['inc_cls_num']) + + if task_idx > 0: + self.cur_lamda = self.kwargs['lamda'] * math.sqrt(lamda_mult) + else: + self.cur_lamda = self.kwargs['lamda'] + + self._init_new_fc(task_idx, buffer, train_loader) + + if task_idx == 0: + self.loss_fn = nn.CrossEntropyLoss() + else: + self.loss_fn1 = nn.CosineEmbeddingLoss() + self.loss_fn2 = nn.CrossEntropyLoss() + self.loss_fn3 = nn.MarginRankingLoss(margin=self.kwargs['dist']) + + self.ref_model.eval() + self.num_old_classes = self.ref_model.classifier.out_features + self.handle_ref_features = self.ref_model.classifier.register_forward_hook(get_ref_features) + self.handle_cur_features = self.network.classifier.register_forward_hook(get_cur_features) + self.handle_old_scores_bs = self.network.classifier.fc1.register_forward_hook(get_old_scores_before_scale) + self.handle_new_scores_bs = self.network.classifier.fc2.register_forward_hook(get_new_scores_before_scale) + + self.network = self.network.to(self.device) + if self.ref_model is not None: + self.ref_model = self.ref_model.to(self.device) + + def _init_new_fc(self, task_idx, buffer, train_loader): + if task_idx == 0: + return + old_embedding_norm = self.network.classifier.fc1.weight.data.norm(dim=1, keepdim=True) + average_old_embedding_norm = torch.mean(old_embedding_norm, dim=0).to('cpu').type(torch.DoubleTensor) + feature_model = self.network.backbone + num_features = self.network.classifier.in_features + novel_embedding = torch.zeros((self.kwargs['inc_cls_num'], num_features)) + + tmp_datasets = copy.deepcopy(train_loader.dataset) + for cls_idx in range(self.network.classifier.fc1.out_features, self.network.classifier.fc1.out_features + self.network.classifier.fc2.out_features): + cls_dataset = train_loader.dataset + task_data, task_target = cls_dataset.images, cls_dataset.labels + cls_indices = np.where(np.array(task_target) == cls_idx) # tuple + cls_data, cls_target = np.array([task_data[i] for i in cls_indices[0]]), np.array([task_target[i] for i in cls_indices[0]]) + tmp_datasets.images = cls_data + tmp_datasets.labels = cls_target + tmp_loader = DataLoader(tmp_datasets, batch_size=128, shuffle=False, num_workers=2) + num_samples = cls_data.shape[0] + cls_features = self._compute_feature(feature_model, tmp_loader, num_samples, num_features) + norm_features = F.normalize(torch.from_numpy(cls_features), p=2, dim=1) + cls_embedding = torch.mean(norm_features, dim=0) + novel_embedding[cls_idx-self.network.classifier.fc1.out_features] = F.normalize(cls_embedding, p=2, dim=0) * average_old_embedding_norm + + self.network.to(self.device) + self.network.classifier.fc2.weight.data = novel_embedding.to(self.device) + + def _compute_feature(self, feature_model, loader, num_samples, num_features): + feature_model.eval() + features = np.zeros([num_samples, num_features]) + start_idx = 0 + with torch.no_grad(): + for batch_idx, batch in enumerate(loader): + inputs, labels = batch['image'], batch['label'] + inputs = inputs.to(self.device) + features[start_idx:start_idx+inputs.shape[0], :] = np.squeeze(feature_model.feature(inputs).cpu()) + start_idx = start_idx+inputs.shape[0] + assert(start_idx==num_samples) + return features + + + def observe(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logit = self.network(x) + + if self.task_idx == 0: + loss = self.loss_fn(logit, y) + else: + ref_outputs = self.ref_model(x) + loss = self.loss_fn1(cur_features, ref_features.detach(), \ + torch.ones(x.size(0)).to(self.device)) * self.cur_lamda + + loss += self.loss_fn2(logit, y) + + outputs_bs = torch.cat((old_scores, new_scores), dim=1) + assert(outputs_bs.size()==logit.size()) + gt_index = torch.zeros(outputs_bs.size()).to(self.device) + gt_index = gt_index.scatter(1, y.view(-1,1), 1).ge(0.5) + gt_scores = outputs_bs.masked_select(gt_index) + max_novel_scores = outputs_bs[:, self.num_old_classes:].topk(self.K, dim=1)[0] + hard_index = y.lt(self.num_old_classes) + hard_num = torch.nonzero(hard_index).size(0) + + if hard_num > 0: + gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, self.K) + max_novel_scores = max_novel_scores[hard_index] + assert(gt_scores.size() == max_novel_scores.size()) + assert(gt_scores.size(0) == hard_num) + loss += self.loss_fn3(gt_scores.view(-1, 1), \ + max_novel_scores.view(-1, 1), torch.ones(hard_num*self.K, 1).to(self.device)) * self.lw_mr + + pred = torch.argmax(logit, dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0), loss + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + if self.task_idx > 0: + self.handle_ref_features.remove() + self.handle_cur_features.remove() + self.handle_old_scores_bs.remove() + self.handle_new_scores_bs.remove() + + def inference(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + logit = self.network(x) + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + def get_parameters(self, config): + if self.task_idx > 0: + #fix the embedding of old classes + ignored_params = list(map(id, self.network.classifier.fc1.parameters())) + base_params = filter(lambda p: id(p) not in ignored_params, \ + self.network.parameters()) + tg_params =[{'params': base_params, 'lr': 0.1, 'weight_decay': 5e-4}, \ + {'params': self.network.classifier.fc1.parameters(), 'lr': 0, 'weight_decay': 0}] + else: + tg_params = self.network.parameters() + + return tg_params \ No newline at end of file diff --git a/core/model/lwf.py b/core/model/lwf.py new file mode 100644 index 0000000000000000000000000000000000000000..88c2a651233ea00c17fa525a8b5b9ca39b08d115 --- /dev/null +++ b/core/model/lwf.py @@ -0,0 +1,81 @@ +import math +import copy +import torch +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +from .finetune import Finetune + +class LWF(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.kwargs = kwargs + self.feat_dim = feat_dim + self.classifier = nn.Linear(self.feat_dim, kwargs['init_cls_num']) + self.old_fc = None + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.known_cls_num = 0 + self.total_cls_num = 0 + self.old_backbone = None + + def freeze(self,nn): + for param in nn.parameters(): + param.requires_grad = False + nn.eval() + return nn + + def update_fc(self): + fc = nn.Linear(self.feat_dim, self.total_cls_num).to(self.device) + if self.classifier is not None: + # del self.old_fc + self.old_fc = self.freeze(copy.deepcopy(self.classifier)) + old_out = self.classifier.out_features + weight = copy.deepcopy(self.classifier.weight.data) + bias = copy.deepcopy(self.classifier.bias.data) + fc.weight.data[:old_out] = weight + fc.bias.data[:old_out] = bias + + # del self.classifier + self.classifier = fc + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + self.known_cls_num = self.total_cls_num + self.total_cls_num = self.init_cls_num + self.task_idx*self.inc_cls_num + self.update_fc() + self.loss_fn = nn.CrossEntropyLoss() + if task_idx != 0: + self.old_backbone = self.freeze(copy.deepcopy(self.backbone)).to(self.device) + + + def observe(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logit = self.classifier(self.backbone(x)['features']) + + if self.task_idx == 0: + loss = self.loss_fn(logit, y) + else: + fake_targets = y - self.known_cls_num + loss_clf = self.loss_fn(logit[:,self.known_cls_num:],fake_targets) + loss_kd = self._KD_loss(logit[:,:self.known_cls_num],self.old_fc(self.old_backbone(x)['features']),T=2) + lamda = 3 + loss = lamda*loss_kd + loss_clf + + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0), loss + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + pass + + def _KD_loss(self, pred, soft, T): + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] + def _cross_entropy(self, pre, logit): + loss = None + return loss diff --git a/core/model/moe_adapter4cl.py b/core/model/moe_adapter4cl.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0274db5130f9c077c237d95c2ac46f32d99253 --- /dev/null +++ b/core/model/moe_adapter4cl.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{yu2024boosting, + title={Boosting continual learning of vision-language models via mixture-of-experts adapters}, + author={Yu, Jiazuo and Zhuge, Yunzhi and Zhang, Lu and Hu, Ping and Wang, Dong and Lu, Huchuan and He, You}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={23219--23230}, + year={2024} +} + +Adapted from https://github.com/JiazuoYu/MoE-Adapters4CL +""" + +import math +import torch +import torch.nn as nn +import numpy as np + +from torch import optim +from torch.nn import functional as F +from torch.nn.parameter import Parameter +from tqdm import tqdm + +from .backbone.clip import tokenize, CLIP +from .backbone.vit import ViTZoo + +VIT = ViTZoo +CLIP = CLIP + +class MOE_ADAPTER4CL(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.device = device + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.label_smoothing = kwargs['label_smoothing'] + + self._known_classes = 0 + self._cur_task_id = -1 + + self.accm_class_names = [] + self.curr_class_names = [] + self.accm_text_tokens = None + self.curr_text_tokens = None + + self.prompt_template = kwargs['prompt_template'] + + self._network = backbone + + self.classifier_pool = nn.ModuleList([ + nn.Linear(kwargs["embd_dim"], kwargs['init_cls_num'], bias=True)] + + [nn.Linear(kwargs["embd_dim"], kwargs['inc_cls_num'], bias=True) for _ in range(kwargs['task_num'] - 1)] + ) + + for name, param in self._network.named_parameters(): + if 'adaptmlp' not in name and 'router' not in name and 'noise' not in name: + param.requires_grad = False + + def observe(self, data): + ''' + Called during the training phase, it inputs a batch of training examples and returns the prediction, accuracy, and forward loss. + ''' + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + if isinstance(self._network, CLIP): + features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.curr_text_tokens) + elif isinstance(self._network, VIT): + features = self._network(x) + logits_per_img = [] + for prompts in [self.classifier_pool[self._cur_task_id]]: + logits_per_img.append(prompts(features)) + logits_per_img = torch.cat(logits_per_img, dim=1) + else: + raise NotImplementedError + + loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing) + + preds = logits_per_img.softmax(dim=-1).argmax(dim=1) + acc = preds.eq(y).sum().item() / y.size(0) + + return preds, acc, loss + + def inference(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + if isinstance(self._network, CLIP): + features_img, features_txt, logits_per_img, logits_per_txt = self._network(x, self.accm_text_tokens) + elif isinstance(self._network, VIT): + features = self._network(x) + logits_per_img = [] + for prompts in self.classifier_pool[:self._cur_task_id + 1]: + logits_per_img.append(prompts(features)) + logits_per_img = torch.cat(logits_per_img, dim=1) + else: + raise NotImplementedError + + preds = logits_per_img.softmax(dim=-1).argmax(dim=1) + acc = preds.eq(y).sum().item() / y.size(0) + + return preds, acc + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self._cur_task_id = task_idx + if task_idx == 1: + self._known_classes = self.init_cls_num + elif task_idx > 1: + self._known_classes += self.inc_cls_num + + self.curr_class_names = train_loader.dataset.get_class_names() + self.accm_class_names += self.curr_class_names + + self.curr_text_tokens = tokenize([self.prompt_template.format(c) for c in self.curr_class_names]).to(self.device) + self.accm_text_tokens = tokenize([self.prompt_template.format(c) for c in self.accm_class_names]).to(self.device) + + def get_parameters(self, config): + return self._network.parameters() \ No newline at end of file diff --git a/core/model/ocm.py b/core/model/ocm.py new file mode 100644 index 0000000000000000000000000000000000000000..5891e0db3fa5f68e9b5d20a3746c53ba5bbb7ee1 --- /dev/null +++ b/core/model/ocm.py @@ -0,0 +1,1019 @@ +""" +@inproceedings{guo2022online, + title={Online continual learning through mutual information maximization}, + author={Guo, Yiduo and Liu, Bing and Zhao, Dongyan}, + booktitle={International Conference on Machine Learning}, + pages={8109--8126}, + year={2022}, + organization={PMLR} +} +https://proceedings.mlr.press/v162/guo22g.html + +Code Reference: +https://github.com/gydpku/OCM/blob/main/test_cifar10.py + +We referred to the original author's code implementation and performed structural refactoring. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from core.model.buffer.onlinebuffer import OnlineBuffer +import math +import numbers +import numpy as np +from torch.autograd import Function +import torch.distributed as dist +import diffdist.functional as distops +from torchvision import transforms + +if torch.__version__ >= '1.4.0': + kwargs = {'align_corners': False} +else: + kwargs = {} + +# ---------------- +import math +import numbers +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function + +if torch.__version__ >= '1.4.0': + kwargs = {'align_corners': False} +else: + kwargs = {} + + +def rgb2hsv(rgb): + """Convert a 4-d RGB tensor to the HSV counterpart. + + Here, we compute hue using atan2() based on the definition in [1], + instead of using the common lookup table approach as in [2, 3]. + Those values agree when the angle is a multiple of 30°, + otherwise they may differ at most ~1.2°. + + References + [1] https://en.wikipedia.org/wiki/Hue + [2] https://www.rapidtables.com/convert/color/rgb-to-hsv.html + [3] https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L212 + """ + + r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :] + + Cmax = rgb.max(1)[0] + Cmin = rgb.min(1)[0] + delta = Cmax - Cmin + + hue = torch.atan2(math.sqrt(3) * (g - b), 2 * r - g - b) + hue = (hue % (2 * math.pi)) / (2 * math.pi) + saturate = delta / Cmax + value = Cmax + hsv = torch.stack([hue, saturate, value], dim=1) + hsv[~torch.isfinite(hsv)] = 0. + return hsv + + +def hsv2rgb(hsv): + """Convert a 4-d HSV tensor to the RGB counterpart. + + >>> %timeit hsv2rgb(hsv) + 2.37 ms ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) + >>> %timeit rgb2hsv_fast(rgb) + 298 µs ± 542 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) + >>> torch.allclose(hsv2rgb(hsv), hsv2rgb_fast(hsv), atol=1e-6) + True + + References + [1] https://en.wikipedia.org/wiki/HSL_and_HSV#HSV_to_RGB_alternative + """ + h, s, v = hsv[:, [0]], hsv[:, [1]], hsv[:, [2]] + c = v * s + + n = hsv.new_tensor([5, 3, 1]).view(3, 1, 1) + k = (n + h * 6) % 6 + t = torch.min(k, 4 - k) + t = torch.clamp(t, 0, 1) + + return v - c * t + + +class RandomResizedCropLayer(nn.Module): + def __init__(self, size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)): + ''' + Inception Crop + size (tuple): size of fowarding image (C, W, H) + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + ''' + super(RandomResizedCropLayer, self).__init__() + + _eye = torch.eye(2, 3) + self.size = size + self.register_buffer('_eye', _eye) + self.scale = scale + self.ratio = ratio + + def forward(self, inputs, whbias=None): + _device = inputs.device + N = inputs.size(0) + _theta = self._eye.repeat(N, 1, 1) + + if whbias is None: + whbias = self._sample_latent(inputs) + + _theta[:, 0, 0] = whbias[:, 0] + _theta[:, 1, 1] = whbias[:, 1] + _theta[:, 0, 2] = whbias[:, 2] + _theta[:, 1, 2] = whbias[:, 3] + + grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device) + output = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs) + + #if self.size is not None: + # output = F.adaptive_avg_pool2d(output, self.size) + + return output#再次仿射取样,——theta考虑whbias + + def _clamp(self, whbias): + + w = whbias[:, 0] + h = whbias[:, 1] + w_bias = whbias[:, 2] + h_bias = whbias[:, 3] + + # Clamp with scale + w = torch.clamp(w, *self.scale) + h = torch.clamp(h, *self.scale) + + # Clamp with ratio + w = self.ratio[0] * h + torch.relu(w - self.ratio[0] * h) + w = self.ratio[1] * h - torch.relu(self.ratio[1] * h - w) + + # Clamp with bias range: w_bias \in (w - 1, 1 - w), h_bias \in (h - 1, 1 - h) + w_bias = w - 1 + torch.relu(w_bias - w + 1) + w_bias = 1 - w - torch.relu(1 - w - w_bias) + + h_bias = h - 1 + torch.relu(h_bias - h + 1) + h_bias = 1 - h - torch.relu(1 - h - h_bias) + + whbias = torch.stack([w, h, w_bias, h_bias], dim=0).t() + + return whbias + + def _sample_latent(self, inputs): + + _device = inputs.device + N, _, width, height = inputs.shape + + # N * 10 trial + area = width * height + target_area = np.random.uniform(*self.scale, N * 10) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = np.exp(np.random.uniform(*log_ratio, N * 10)) + + # If doesn't satisfy ratio condition, then do central crop + w = np.round(np.sqrt(target_area * aspect_ratio)) + h = np.round(np.sqrt(target_area / aspect_ratio)) + cond = (0 < w) * (w <= width) * (0 < h) * (h <= height) + w = w[cond] + h = h[cond] + cond_len = w.shape[0] + if cond_len >= N: + w = w[:N] + h = h[:N] + else: + w = np.concatenate([w, np.ones(N - cond_len) * width]) + h = np.concatenate([h, np.ones(N - cond_len) * height]) + + w_bias = np.random.randint(w - width, width - w + 1) / width + h_bias = np.random.randint(h - height, height - h + 1) / height + w = w / width + h = h / height + + whbias = np.column_stack([w, h, w_bias, h_bias]) + whbias = torch.tensor(whbias, device=_device) + + return whbias + + +class HorizontalFlipRandomCrop(nn.Module): + def __init__(self, max_range): + super(HorizontalFlipRandomCrop, self).__init__() + self.max_range = max_range + _eye = torch.eye(2, 3) + self.register_buffer('_eye', _eye) + + def forward(self, input, sign=None, bias=None, rotation=None): + _device = input.device + N = input.size(0) + _theta = self._eye.repeat(N, 1, 1) + + if sign is None: + sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1 + if bias is None: + bias = torch.empty((N, 2), device=_device).uniform_(-self.max_range, self.max_range) + _theta[:, 0, 0] = sign + _theta[:, :, 2] = bias + + if rotation is not None: + _theta[:, 0:2, 0:2] = rotation + + grid = F.affine_grid(_theta, input.size(), **kwargs).to(_device) + output = F.grid_sample(input, grid, padding_mode='reflection', **kwargs) + + return output + + def _sample_latent(self, N, device=None): + sign = torch.bernoulli(torch.ones(N, device=device) * 0.5) * 2 - 1 + bias = torch.empty((N, 2), device=device).uniform_(-self.max_range, self.max_range) + return sign, bias + + +class Rotation(nn.Module): + def __init__(self, max_range = 4): + super(Rotation, self).__init__() + self.max_range = max_range + self.prob = 0.5 + + def forward(self, input, aug_index=None): + _device = input.device + #print(self.prob) + _, _, H, W = input.size() + + if aug_index is None: + aug_index = np.random.randint(4)#随机四个里生成一个数 + + output = torch.rot90(input, aug_index, (2, 3))#如果是aug》0,从y轴转向x轴,转90*aug,反之亦然。(2,3)是要转的维度 + + _prob = input.new_full((input.size(0),), self.prob)#产生一个inputsize大小,值为0.5的tensor,不会加在a上,直接给prob + _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)#按照prob中p用beinoulli生成0/1值,实际上是每个样本是否输出的mask + output = _mask * input + (1-_mask) * output#这样做要么是原图像,要么旋转90*aug + + else: + aug_index = aug_index % self.max_range + output = torch.rot90(input, aug_index, (2, 3))#旋转角度不mask,原样返回 + + return output + + +class CutPerm(nn.Module): + def __init__(self, max_range = 4): + super(CutPerm, self).__init__() + self.max_range = max_range + self.prob = 0.5 + + def forward(self, input, aug_index=None): + _device = input.device + + _, _, H, W = input.size() + + if aug_index is None: + aug_index = np.random.randint(4) + + output = self._cutperm(input, aug_index) + + _prob = input.new_full((input.size(0),), self.prob) + _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1) + output = _mask * input + (1 - _mask) * output + + else: + aug_index = aug_index % self.max_range + output = self._cutperm(input, aug_index) + + return output + + def _cutperm(self, inputs, aug_index): + + _, _, H, W = inputs.size() + h_mid = int(H / 2) + w_mid = int(W / 2) + + jigsaw_h = aug_index // 2 + jigsaw_v = aug_index % 2 + + if jigsaw_h == 1: + inputs = torch.cat((inputs[:, :, h_mid:, :], inputs[:, :, 0:h_mid, :]), dim=2) + if jigsaw_v == 1: + inputs = torch.cat((inputs[:, :, :, w_mid:], inputs[:, :, :, 0:w_mid]), dim=3) + + return inputs + + +class HorizontalFlipLayer(nn.Module): + def __init__(self): + """ + img_size : (int, int, int) + Height and width must be powers of 2. E.g. (32, 32, 1) or + (64, 128, 3). Last number indicates number of channels, e.g. 1 for + grayscale or 3 for RGB + """ + super(HorizontalFlipLayer, self).__init__() + + _eye = torch.eye(2, 3)#对角矩阵取前两行 + self.register_buffer('_eye', _eye) + + def forward(self, inputs): + _device = inputs.device + + N = inputs.size(0)#batch——size + _theta = self._eye.repeat(N, 1, 1)#重复N份,拼一起 + r_sign = torch.bernoulli(torch.ones(N, device=_device) * 0.5) * 2 - 1#0.5概率生成mask + _theta[:, 0, 0] = r_sign#把mask加入 + grid = F.affine_grid(_theta, inputs.size(), **kwargs).to(_device) + inputs = F.grid_sample(inputs, grid, padding_mode='reflection', **kwargs) + + return inputs#做一系列仿射变换,得到图像 + + +class RandomColorGrayLayer(nn.Module): + def __init__(self, p): + super(RandomColorGrayLayer, self).__init__() + self.prob = p#0.2 + + _weight = torch.tensor([[0.299, 0.587, 0.114]]) + self.register_buffer('_weight', _weight.view(1, 3, 1, 1)) + + def forward(self, inputs, aug_index=None): + + if aug_index == 0: + return inputs + + l = F.conv2d(inputs, self._weight)#卷积处理,只有一个轨道了 + gray = torch.cat([l, l, l], dim=1)#通道扩增3倍,得到原来的大小 + + if aug_index is None: + _prob = inputs.new_full((inputs.size(0),), self.prob) + _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1) + + gray = inputs * (1 - _mask) + gray * _mask + + return gray + + +class ColorJitterLayer(nn.Module): + def __init__(self, p, brightness, contrast, saturation, hue): + super(ColorJitterLayer, self).__init__() + self.prob = p#0.8 + self.brightness = self._check_input(brightness, 'brightness')#[0.6,1.4] + self.contrast = self._check_input(contrast, 'contrast')#[0.6,1.4] + self.saturation = self._check_input(saturation, 'saturation')#[0.6,1.4] + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False)#hue 0.8,return[-0.1,0.1] + + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - value, center + value]#hue[-0.1,0.1] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + def adjust_contrast(self, x): + if self.contrast: + factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)# + means = torch.mean(x, dim=[2, 3], keepdim=True)#【batch——size,3,1,1】 + x = (x - means) * factor + means#【32】【3】每个先减去对应means,再【32】乘以一个【0.6到1.4】中对应数,然后加(1-factor)*means 也是对应【32】加 + return torch.clamp(x, 0, 1)#维持在0,1中 + + def adjust_hsv(self, x): + f_h = x.new_zeros(x.size(0), 1, 1) + f_s = x.new_ones(x.size(0), 1, 1) + f_v = x.new_ones(x.size(0), 1, 1)#生成(batch_size,1,1)的0/1矩阵 + + if self.hue: + f_h.uniform_(*self.hue)#生成【batch_size,1,1】其中值在-0.1,0.1之间 + if self.saturation: + f_s = f_s.uniform_(*self.saturation)#同事,值在0.6到1.4之间 + if self.brightness: + f_v = f_v.uniform_(*self.brightness) + + return RandomHSVFunction.apply(x, f_h, f_s, f_v)#对每个通道做一些随机HSV变化 + + def transform(self, inputs): + # Shuffle transform + if np.random.rand() > 0.5: + transforms = [self.adjust_contrast, self.adjust_hsv] + else: + transforms = [self.adjust_hsv, self.adjust_contrast] + + for t in transforms: + inputs = t(inputs)#对input随机套两个组合比较是必须的 + + return inputs + + def forward(self, inputs): + _prob = inputs.new_full((inputs.size(0),), self.prob) + _mask = torch.bernoulli(_prob).view(-1, 1, 1, 1)#生成mask + return inputs * (1 - _mask) + self.transform(inputs) * _mask + + +class RandomHSVFunction(Function): + @staticmethod + def forward(ctx, x, f_h, f_s, f_v): + # ctx is a context object that can be used to stash information + # for backward computation + x = rgb2hsv(x)#从 hsv tensor 变 RGB tensor + h = x[:, 0, :, :]#第一个通道【32,32,32】 + h += (f_h * 255. / 360.)#给每个在【32】中的值加f_h*255/360 对应的那个位置的值 + h = (h % 1)#求余数 + x[:, 0, :, :] = h#第一个通道这样,加法然后取余 + x[:, 1, :, :] = x[:, 1, :, :] * f_s#这里只是乘 + x[:, 2, :, :] = x[:, 2, :, :] * f_v + x = torch.clamp(x, 0, 1)#裁剪,超过0,1范围的变0/1 + x = hsv2rgb(x)#返回 + return x + + @staticmethod + def backward(ctx, grad_output): + # We return as many input gradients as there were arguments. + # Gradients of non-Tensor arguments to forward must be None. + grad_input = None + if ctx.needs_input_grad[0]: + grad_input = grad_output.clone() + return grad_input, None, None, None + + +class NormalizeLayer(nn.Module): + """ + In order to certify radii in original coordinates rather than standardized coordinates, we + add the Gaussian noise _before_ standardizing, which is why we have standardization be the first + layer of the classifier rather than as a part of preprocessing as is typical. + """ + + def __init__(self): + super(NormalizeLayer, self).__init__() + + def forward(self, inputs): + return (inputs - 0.5) / 0.5 + +import torch +from torch import Tensor +from torchvision.transforms.functional import to_pil_image, to_tensor +from torch.nn.functional import conv2d, pad as torch_pad +from typing import Any, List, Sequence, Optional +import numbers +import numpy as np +import torch +from PIL import Image +from typing import Tuple + +class GaussianBlur(torch.nn.Module): + """Blurs image with randomly chosen Gaussian blur. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., C, H, W] shape, where ... means an arbitrary number of leading + dimensions + Args: + kernel_size (int or sequence): Size of the Gaussian kernel. + sigma (float or tuple of float (min, max)): Standard deviation to be used for + creating kernel to perform blurring. If float, sigma is fixed. If it is tuple + of float (min, max), sigma is chosen uniformly at random to lie in the + given range. + Returns: + PIL Image or Tensor: Gaussian blurred version of the input image. + """ + + def __init__(self, kernel_size, sigma=(0.1, 2.0)): + super().__init__() + self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers") + for ks in self.kernel_size: + if ks <= 0 or ks % 2 == 0: + raise ValueError("Kernel size value should be an odd and positive number.") + + if isinstance(sigma, numbers.Number): + if sigma <= 0: + raise ValueError("If sigma is a single number, it must be positive.") + sigma = (sigma, sigma) + elif isinstance(sigma, Sequence) and len(sigma) == 2: + if not 0. < sigma[0] <= sigma[1]: + raise ValueError("sigma values should be positive and of the form (min, max).") + else: + raise ValueError("sigma should be a single number or a list/tuple with length 2.") + + self.sigma = sigma + + @staticmethod + def get_params(sigma_min: float, sigma_max: float) -> float: + """Choose sigma for random gaussian blurring. + Args: + sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel. + sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel. + Returns: + float: Standard deviation to be passed to calculate kernel for gaussian blurring. + """ + return torch.empty(1).uniform_(sigma_min, sigma_max).item() + + def forward(self, img: Tensor) -> Tensor: + """ + Args: + img (PIL Image or Tensor): image to be blurred. + Returns: + PIL Image or Tensor: Gaussian blurred image + """ + sigma = self.get_params(self.sigma[0], self.sigma[1]) + return gaussian_blur(img, self.kernel_size, [sigma, sigma]) + + def __repr__(self): + s = '(kernel_size={}, '.format(self.kernel_size) + s += 'sigma={})'.format(self.sigma) + return self.__class__.__name__ + s + +@torch.jit.unused +def _is_pil_image(img: Any) -> bool: + return isinstance(img, Image.Image) +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) + + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size +def _is_tensor_a_torch_image(x: Tensor) -> bool: + return x.ndim >= 2 +def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + kernel1d = pdf / pdf.sum() + + return kernel1d + +def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]: + need_squeeze = False + # make image NCHW + if img.ndim < 4: + img = img.unsqueeze(dim=0) + need_squeeze = True + + out_dtype = img.dtype + need_cast = False + if out_dtype != req_dtype: + need_cast = True + img = img.to(req_dtype) + return img, need_cast, need_squeeze, out_dtype +def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype): + if need_squeeze: + img = img.squeeze(dim=0) + + if need_cast: + # it is better to round before cast + img = torch.round(img).to(out_dtype) + + return img +def _get_gaussian_kernel2d( + kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device +) -> Tensor: + kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype) + kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype) + kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :]) + return kernel2d +def _gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: + """PRIVATE METHOD. Performs Gaussian blurring on the img by given kernel. + .. warning:: + Module ``transforms.functional_tensor`` is private and should not be used in user application. + Please, consider instead using methods from `transforms.functional` module. + Args: + img (Tensor): Image to be blurred + kernel_size (sequence of int or int): Kernel size of the Gaussian kernel ``(kx, ky)``. + sigma (sequence of float or float, optional): Standard deviation of the Gaussian kernel ``(sx, sy)``. + Returns: + Tensor: An image that is blurred using gaussian kernel of given parameters + """ + if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)): + raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) + kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) + + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, kernel.dtype) + + # padding = (left, right, top, bottom) + padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] + img = torch_pad(img, padding, mode="reflect") + img = conv2d(img, kernel, groups=img.shape[-3]) + + img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) + return img + +def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor: + """Performs Gaussian blurring on the img by given kernel. + The image can be a PIL Image or a Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + Args: + img (PIL Image or Tensor): Image to be blurred + kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers + like ``(kx, ky)`` or a single integer for square kernels. + In torchscript mode kernel_size as single int is not supported, use a tuple or + list of length 1: ``[ksize, ]``. + sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a + sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the + same sigma in both X/Y directions. If None, then it is computed using + ``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``. + Default, None. In torchscript mode sigma as single float is + not supported, use a tuple or list of length 1: ``[sigma, ]``. + Returns: + PIL Image or Tensor: Gaussian Blurred version of the image. + """ + if not isinstance(kernel_size, (int, list, tuple)): + raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size))) + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] + if len(kernel_size) != 2: + raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size))) + for ksize in kernel_size: + if ksize % 2 == 0 or ksize < 0: + raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size)) + + if sigma is None: + sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] + + if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): + raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma))) + if isinstance(sigma, (int, float)): + sigma = [float(sigma), float(sigma)] + if isinstance(sigma, (list, tuple)) and len(sigma) == 1: + sigma = [sigma[0], sigma[0]] + if len(sigma) != 2: + raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma))) + for s in sigma: + if s <= 0.: + raise ValueError('sigma should have positive values. Got {}'.format(sigma)) + + t_img = img + if not isinstance(img, torch.Tensor): + if not _is_pil_image(img): + raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img))) + + t_img = to_tensor(img) + + output = _gaussian_blur(t_img, kernel_size, sigma) + + if not isinstance(img, torch.Tensor): + output = to_pil_image(output) + return output + +# --------------- + + + + +def normalize(x, dim=1, eps=1e-8): + return x / (x.norm(dim=dim, keepdim=True) + eps) + + +def rot_inner_all(x): + num = x.shape[0] + + image_size = x.shape[2] + + R = x.repeat(4, 1, 1, 1) + a = x.permute(0, 1, 3, 2) + + a = a.view(num, 3, 2, image_size//2, image_size) + a = a.permute(2, 0, 1, 3, 4) + s1 = a[0] + s2 = a[1] + s1_1 = torch.rot90(s1, 2, (2, 3)) + s2_2 = torch.rot90(s2, 2, (2, 3)) + R[num: 2 * num] = torch.cat((s1_1.unsqueeze(2), s2.unsqueeze(2)), dim=2).reshape(num,3, image_size, image_size).permute(0, 1, 3, 2) + R[3 * num:] = torch.cat((s1.unsqueeze(2), s2_2.unsqueeze(2)), dim=2).reshape(num,3, image_size, image_size).permute(0, 1, 3, 2) + R[2 * num: 3 * num] = torch.cat((s1_1.unsqueeze(2), s2_2.unsqueeze(2)), dim=2).reshape(num,3, image_size, image_size).permute(0, 1, 3, 2) + return R + + +def Rotation(x, y): + num = x.shape[0] + X = rot_inner_all(x) + y = y.repeat(16) + for i in range(1, 16): + y[i * num:(i + 1) * num]+=1000 * i + return torch.cat((X, torch.rot90(X, 1, (2, 3)), torch.rot90(X, 2, (2, 3)), torch.rot90(X, 3, (2, 3))), dim=0), y + + + + + +def get_similarity_matrix(outputs, chunk=2, multi_gpu=False): + ''' + Compute similarity matrix + - outputs: (B', d) tensor for B' = B * chunk + - sim_matrix: (B', B') tensor + + Code Reference: + https://github.com/gydpku/OCM/blob/main/test_cifar10.py + ''' + if multi_gpu: + outputs_gathered = [] + for out in outputs.chunk(chunk): + gather_t = [torch.empty_like(out) for _ in range(dist.get_world_size())] + gather_t = torch.cat(distops.all_gather(gather_t, out)) + outputs_gathered.append(gather_t) + outputs = torch.cat(outputs_gathered) + sim_matrix = torch.mm(outputs, outputs.t()) + + return sim_matrix + + +def Supervised_NT_xent_n(sim_matrix, labels, embedding=None,temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False): + ''' + Compute NT_xent loss + - sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples) + + Code Reference: + https://github.com/gydpku/OCM/blob/main/test_cifar10.py + ''' + device = sim_matrix.device + labels1 = labels.repeat(2) + logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True) + sim_matrix = sim_matrix - logits_max.detach() + B = sim_matrix.size(0) // chunk # B = B' / chunk + eye = torch.eye(B * chunk).to(device) # (B', B') + sim_matrix = torch.exp(sim_matrix / temperature) * (1 - eye) + denom = torch.sum(sim_matrix, dim=1, keepdim=True) + sim_matrix = -torch.log(sim_matrix/(denom + eps) + eps) + labels1 = labels1.contiguous().view(-1, 1) + Mask1 = torch.eq(labels1, labels1.t()).float().to(device) + Mask1 = Mask1 / (Mask1.sum(dim=1, keepdim=True) + eps) + loss1 = 2 * torch.sum(Mask1 * sim_matrix) / (2 * B) + return (torch.sum(sim_matrix[:B, B:].diag() + sim_matrix[B:, :B].diag()) / (2 * B)) + loss1 + + +def Supervised_NT_xent_uni(sim_matrix, labels, temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False): + ''' + Compute NT_xent loss + - sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples) + + Code Reference: + https://github.com/gydpku/OCM/blob/main/test_cifar10.py + ''' + device = sim_matrix.device + labels1 = labels.repeat(2) + logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True) + sim_matrix = sim_matrix - logits_max.detach() + B = sim_matrix.size(0) // chunk + sim_matrix = torch.exp(sim_matrix / temperature) + denom = torch.sum(sim_matrix, dim=1, keepdim=True) + sim_matrix = - torch.log(sim_matrix / (denom + eps) + eps) + labels1 = labels1.contiguous().view(-1, 1) + Mask1 = torch.eq(labels1, labels1.t()).float().to(device) + Mask1 = Mask1 / (Mask1.sum(dim=1, keepdim=True) + eps) + return torch.sum(Mask1 * sim_matrix) / (2 * B) + + + + + +def Supervised_NT_xent_pre(sim_matrix, labels, temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False): + ''' + Compute NT_xent loss + - sim_matrix: (B', B') tensor for B' = B * chunk (first 2B are pos samples) + + Code Reference: + https://github.com/gydpku/OCM/blob/main/test_cifar10.py + ''' + device = sim_matrix.device + labels1 = labels#.repeat(2) + logits_max, _ = torch.max(sim_matrix, dim=1, keepdim=True) + sim_matrix = sim_matrix - logits_max.detach() + B = sim_matrix.size(0) // chunk + sim_matrix = torch.exp(sim_matrix / temperature) + denom = torch.sum(sim_matrix, dim=1, keepdim=True) + sim_matrix = -torch.log(sim_matrix/(denom+eps)+eps) # loss matrix + labels1 = labels1.contiguous().view(-1, 1) + Mask1 = torch.eq(labels1, labels1.t()).float().to(device) + Mask1 = Mask1 / (Mask1.sum(dim=1, keepdim=True) + eps) + return torch.sum(Mask1 * sim_matrix) / (2 * B) + + + +######################################################### +# # +# Model # +# # +######################################################### + + + +class OCM_Model(nn.Module): + + def __init__(self, backbone, feat_dim, num_class, device): + ''' + A OCM model consists of a backbone, a classifier and a self-supervised head + ''' + + super(OCM_Model, self).__init__() + self.backbone = backbone + self.classifier = nn.Linear(feat_dim, num_class) + self.head = nn.Linear(feat_dim, 128) # for self-supervise + self.device = device + + def get_features(self, x): + out = self.backbone(x)['features'] + return out + + + def forward_head(self, x): + feat = self.get_features(x) + out = self.head(feat) + return feat, out + + + def forward_classifier(self, x): + feat = self.get_features(x) + logits = self.classifier(feat) + return logits + +class OCM(nn.Module): + + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super(OCM, self).__init__() + + # device setting + self.device = kwargs['device'] + + # current task index + self.cur_task_id = 0 + + # # current task class indexes + # self.cur_cls_indexes = None + + # Build model structure + self.model = OCM_Model(backbone, feat_dim, num_class, self.device) + + # Store old network + self.previous_model = None + + # Store all seen classes + self.class_holder = [] + + self.buffer_per_class = 7 + + + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.task_num = kwargs['task_num'] + self.image_size = kwargs['image_size'] + + self.simclr_aug = torch.nn.Sequential( + HorizontalFlipLayer().to(self.device), + RandomColorGrayLayer(p=0.25).to(self.device), + RandomResizedCropLayer(scale=(0.3, 1.0), size=[self.image_size, self.image_size, 3]).to(self.device) + ) + + def observe(self, data): + # get data and labels + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + + # update seen classes + Y = deepcopy(y) + for j in range(len(Y)): + if Y[j] not in self.class_holder: + self.class_holder.append(Y[j].detach()) + + + # learning + x = x.requires_grad_() + + if self.cur_task_id == 0: + pred, acc, loss = self.observe_first_task(x, y) + else: + pred, acc, loss = self.observe_incremental_tasks(x, y) + + # sample data to buffer + self.buffer.add_reservoir(x=x.detach(), y=y.detach(), task=self.cur_task_id) + + return pred, acc, loss + + + + def observe_first_task(self, x, y): + """ + Code Reference: + https://github.com/gydpku/OCM/blob/main/test_cifar10.py + """ + images1, rot_sim_labels = Rotation(x, y) + images_pair = torch.cat([images1, self.simclr_aug(images1)], dim=0) + rot_sim_labels = rot_sim_labels.cuda() + feature_map,outputs_aux = self.model.forward_head(images_pair) + simclr = normalize(outputs_aux) + feature_map_out = normalize(feature_map[:images_pair.shape[0]]) + num1 = feature_map_out.shape[1] - simclr.shape[1] + id1 = torch.randperm(num1)[0] + size = simclr.shape[1] + sim_matrix = torch.matmul(simclr, feature_map_out[:, id1 :id1+ 1 * size].t()) + sim_matrix += get_similarity_matrix(simclr) + loss_sim1 = Supervised_NT_xent_n(sim_matrix, labels=rot_sim_labels, temperature=0.07) + lo1 = loss_sim1 + y_pred = self.model.forward_classifier(self.simclr_aug(x)) + loss = F.cross_entropy(y_pred, y) + lo1 + pred = torch.argmin(y_pred, dim=1) + acc = torch.sum(pred == y).item() / x.size(0) + + return y_pred, acc, loss + + + + def observe_incremental_tasks(self, x, y): + """ + Code Reference: + https://github.com/gydpku/OCM/blob/main/test_cifar10.py + """ + buffer_batch_size = min(64, self.buffer_per_class*len(self.class_holder)) + mem_x, mem_y,_ = self.buffer.sample(buffer_batch_size, exclude_task=None) + mem_x = mem_x.requires_grad_() + images1, rot_sim_labels = Rotation(x, y) + images1_r, rot_sim_labels_r = Rotation(mem_x, + mem_y) + images_pair = torch.cat([images1, self.simclr_aug(images1)], dim=0) + images_pair_r = torch.cat([images1_r, self.simclr_aug(images1_r)], dim=0) + t = torch.cat((images_pair,images_pair_r),dim=0) + feature_map, u = self.model.forward_head(t) + pre_u_feature, pre_u = self.previous_model.forward_head(images1_r) + feature_map_out = normalize(feature_map[:images_pair.shape[0]]) + feature_map_out_r = normalize(feature_map[images_pair.shape[0]:]) + images_out = u[:images_pair.shape[0]] + images_out_r = u[images_pair.shape[0]:] + pre_u = normalize(pre_u) + simclr = normalize(images_out) + simclr_r = normalize(images_out_r) + num1 = feature_map_out.shape[1] - simclr.shape[1] + id1 = torch.randperm(num1)[0] + id2 = torch.randperm(num1)[0] + size = simclr.shape[1] + + sim_matrix = torch.matmul(simclr, feature_map_out[:, id1:id1 + size].t()) + sim_matrix_r = torch.matmul(simclr_r, feature_map_out_r[:, id2:id2 + size].t()) + sim_matrix += get_similarity_matrix(simclr) + sim_matrix_r += get_similarity_matrix(simclr_r) + sim_matrix_r_pre = torch.matmul(simclr_r[:images1_r.shape[0]],pre_u.t()) + loss_sim_r =Supervised_NT_xent_uni(sim_matrix_r,labels=rot_sim_labels_r,temperature=0.07) + loss_sim_pre = Supervised_NT_xent_pre(sim_matrix_r_pre, labels=rot_sim_labels_r, temperature=0.07) + loss_sim = Supervised_NT_xent_n(sim_matrix, labels=rot_sim_labels, temperature=0.07) + lo1 = loss_sim_r + loss_sim + loss_sim_pre + y_label = self.model.forward_classifier(self.simclr_aug(mem_x)) + y_label_pre = self.previous_model.forward_classifier(self.simclr_aug(mem_x)) + loss = F.cross_entropy(y_label, mem_y) + lo1 + F.mse_loss(y_label_pre[:, :self.prev_cls_num], + y_label[:, + :self.prev_cls_num]) + + with torch.no_grad(): + logits = self.model.forward_classifier(x)[:, :self.accu_cls_num] + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() / x.size(0) + return logits, acc, loss + + + + + def inference(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logits = self.model.forward_classifier(x) + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + # load buffer to the models + if self.cur_task_id == 0: + self.buffer = buffer + + if self.cur_task_id == 0: + self.accu_cls_num = self.init_cls_num + else: + self.accu_cls_num += self.inc_cls_num + + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + self.prev_cls_num = self.accu_cls_num + self.cur_task_id += 1 + self.previous_model = deepcopy(self.model) + + + def get_parameters(self, config): + return self.model.parameters() \ No newline at end of file diff --git a/core/model/praka.py b/core/model/praka.py new file mode 100644 index 0000000000000000000000000000000000000000..bb49551517b798cbf785c5834d560929d365d41f --- /dev/null +++ b/core/model/praka.py @@ -0,0 +1,340 @@ +""" +@inproceedings{DBLP:conf/iccv/ShiY23, + title = {Prototype Reminiscence and Augmented Asymmetric Knowledge Aggregation for Non-Exemplar Class-Incremental Learning}, + author = {Shi, Wuxuan and Ye, Mang}, + booktitle = {2023 IEEE/CVF International Conference on Computer Vision (ICCV)}, + pages = {1772-1781}, + publisher = {Computer Vision Foundation / {IEEE}}, + year = {2023} +} + +https://openaccess.thecvf.com/content/ICCV2023/papers/Shi_Prototype_Reminiscence_and_Augmented_Asymmetric_Knowledge_Aggregation_for_Non-Exemplar_Class-Incremental_ICCV_2023_paper.pdf + +Adapted from https://github.com/ShiWuxuan/PRAKA +""" + +from torch.nn import functional as F +import os +import numpy as np +import torch +import torch.nn as nn +import math +import copy +from core.model import Finetune + +class joint_network(nn.Module): + def __init__(self, numclass, feature_extractor): + ''' + Code Reference: + https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py + ''' + super(joint_network, self).__init__() + self.feature = feature_extractor + self.fc = nn.Linear(512, numclass * 4, bias=True) + self.classifier = nn.Linear(512, numclass, bias=True) + + def forward(self, input): + ''' + Code Reference: + https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py + ''' + x = self.feature(input) + x = self.classifier(x) + return x + + def Incremental_learning(self, numclass): + ''' + Update the fully connected (fc) layer and classifier layer to accommodate the new number of classes. + + This function modifies the output dimensions of the model's fully connected layer (`fc`) + and the classifier layer based on the total number of classes after the current task. + It ensures that the new layers retain the weights and biases from the previous configuration + for the classes that were previously learned. + + Parameters: + - numclass (int): The total number of classes after the current task, including both old and new classes. + + Notes: + - The `fc` layer's output dimension is set to `numclass * 4`. + - The classifier layer is adjusted to match the new total number of classes, while retaining the previously learned weights and biases. + + Code Reference: + https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py + ''' + weight = self.fc.weight.data + bias = self.fc.bias.data + in_feature = self.fc.in_features + out_feature = self.fc.out_features + + self.fc = nn.Linear(in_feature, numclass * 4, bias=True) + self.fc.weight.data[:out_feature] = weight[:out_feature] + self.fc.bias.data[:out_feature] = bias[:out_feature] + + weight = self.classifier.weight.data + bias = self.classifier.bias.data + in_feature = self.classifier.in_features + out_feature = self.classifier.out_features + + self.classifier = nn.Linear(in_feature, numclass, bias=True) + self.classifier.weight.data[:out_feature] = weight[:out_feature] + self.classifier.bias.data[:out_feature] = bias[:out_feature] + + def feature_extractor(self, inputs): + ''' + Code Reference: + https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/myNetwork.py + ''' + return self.feature(inputs) + +class PRAKA(nn.Module): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + #super().__init__(backbone, feat_dim, num_class, **kwargs) + super().__init__() + self.device = kwargs['device'] + self.kwargs = kwargs + self.size = 32 + # Initialize the feature extractor with a custom ResNet18 structure. + encoder = backbone + self.model = joint_network(kwargs["init_cls_num"], encoder) + self.radius = 0 + self.prototype = None + self.numsamples = None + self.numclass = kwargs["init_cls_num"] + self.task_size = kwargs["inc_cls_num"] + self.old_model = None + # save the model and its corresponding task_id + self.task_idx = 0 + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + if task_idx > 0: + self.model.Incremental_learning(self.numclass) + self.model.to(self.device) + + def observe(self, data): + ''' + Processes a batch of training data to compute predictions, accuracy, and loss. + + Parameters: + - data: Dictionary containing the batch of training samples + - 'image': Tensor of input images + - 'label': Tensor of ground truth labels + + Returns: + - predictions: Tensor of predicted class labels for the input images + - accuracy: Float value representing the accuracy of the model on the current batch + - loss: Float value representing the computed loss for the batch + + Description: + This function is called during the training phase. It performs the following steps: + 1. Extracts the images and labels from the provided data dictionary and transfers them to the device. + 2. Augments the images by rotating them by 0, 90, 180, and 270 degrees, and creates corresponding labels for these augmented images. + 3. Computes the loss using the augmented images and labels. + 4. Evaluates the model's performance on the current batch by calculating the accuracy and loss. + 5. Returns the predictions, accuracy, and loss for the batch. + + Example Usage: + predictions, accuracy, loss = observe(data) + ''' + images, labels = data['image'].to(self.device), data['label'].to(self.device) + + # Generate four times the number of images by rotating each image 0°, 90°, 180°, and 270°. + images = torch.stack([torch.rot90(images, k, (2, 3)) for k in range(4)], 1) + images = images.view(-1, 3, self.size, self.size) + # Generate corresponding labels for the rotated images, each original label produces four new labels. + joint_labels = torch.stack([labels * 4 + k for k in range(4)], 1).view(-1) + if self.task_idx == 0: + old_class = 0 + else: + old_class = self.kwargs['init_cls_num'] + self.kwargs['inc_cls_num'] * (self.task_idx - 1) + # Compute loss and predictions for a batch + loss, single_preds = self._compute_loss(images, joint_labels, labels, old_class) + + preds = torch.argmax(single_preds, dim=-1) + return preds, (preds == labels).sum().item() / len(labels), loss + + def inference(self, data): + ''' + Performs inference on a batch of test samples and computes the classification results and accuracy. + + Parameters: + - data: Dictionary containing the batch of test samples + - 'image': Tensor of input images + - 'label': Tensor of ground truth labels + + Returns: + - predictions: Tensor of predicted class labels for the input images + - accuracy: Float value representing the accuracy of the model on the current batch + + Example Usage: + predictions, accuracy = inference(data) + ''' + + imgs, labels = data['image'].to(self.device), data['label'].to(self.device) + + preds = torch.argmax(self.model(imgs), dim=-1) + + return preds, (preds == labels).sum().item() / len(labels) + + def _compute_loss(self, imgs, joint_labels, labels, old_class=0): + ''' + Computes the loss for a batch of images and labels. + + Parameters: + - imgs: Tensor of input images + - joint_labels: Tensor of labels for images augmented with rotations (0°, 90°, 180°, 270°) + - labels: Tensor of ground truth labels for the images + - old_class: Integer indicating the number of old classes (default is 0) + + Returns: + - loss: Scalar tensor representing the total computed loss + - preds: Tensor of predictions for the original (non-augmented) images + + Example Usage: + loss, preds = self._compute_loss(imgs, joint_labels, labels, old_class) + + + Code Reference: + https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/jointSSL.py + ''' + # Feature extraction + feature = self.model.feature(imgs) + + # Classification predictions + joint_preds = self.model.fc(feature) + single_preds = self.model.classifier(feature)[::4] + joint_preds, joint_labels, single_preds, labels = joint_preds.to(self.device), joint_labels.to(self.device), single_preds.to(self.device), labels.to(self.device) + joint_loss = nn.CrossEntropyLoss()(joint_preds/self.kwargs["temp"], joint_labels) + single_loss = nn.CrossEntropyLoss()(single_preds/self.kwargs["temp"], labels) + + # Average loss for images generated by rotating 4 angles + agg_preds = 0 + for i in range(4): + agg_preds = agg_preds + joint_preds[i::4, i::4] / 4 + # Compute distillation loss between single predictions and aggregated predictions + distillation_loss = F.kl_div(F.log_softmax(single_preds, 1), + F.softmax(agg_preds.detach(), 1), + reduction='batchmean') + if old_class == 0: + return joint_loss + single_loss + distillation_loss, single_preds + else: + feature_old = self.old_model.feature(imgs) + + loss_kd = torch.dist(feature, feature_old, 2) + + # Prototype augmentation + proto_aug = [] + proto_aug_label = [] + old_class_list = list(self.prototype.keys()) + for _ in range(feature.shape[0] // 4): # batch_size = feature.shape[0] // 4 + i = np.random.randint(0, feature.shape[0]) + np.random.shuffle(old_class_list) + lam = np.random.beta(0.5, 0.5) + if lam > 0.6: + lam = lam * 0.6 + + if np.random.random() >= 0.5: + # Weighted combination of prototype (fixed image from old dataset) and current feature + temp = (1 + lam) * self.prototype[old_class_list[0]] - lam * feature.detach().cpu().numpy()[i] + else: + temp = (1 - lam) * self.prototype[old_class_list[0]] + lam * feature.detach().cpu().numpy()[i] + + # Append the generated augmented features and corresponding labels to proto_aug and proto_aug_label + proto_aug.append(temp) + proto_aug_label.append(old_class_list[0]) + + proto_aug = torch.from_numpy(np.float32(np.asarray(proto_aug))).float().to(self.device) + proto_aug_label = torch.from_numpy(np.asarray(proto_aug_label)).to(self.device) + aug_preds = self.model.classifier(proto_aug) + joint_aug_preds = self.model.fc(proto_aug) + agg_preds = joint_aug_preds[:, ::4] + aug_distillation_loss = F.kl_div(F.log_softmax(aug_preds, 1), + F.softmax(agg_preds.detach(), 1), + reduction='batchmean') + # Calculate the weighted sum of cross-entropy loss and distillation loss for augmented data + loss_protoAug = nn.CrossEntropyLoss()(aug_preds/self.kwargs["temp"], proto_aug_label) + nn.CrossEntropyLoss()(joint_aug_preds/self.kwargs["temp"], proto_aug_label*4) + aug_distillation_loss + return joint_loss + single_loss + distillation_loss + self.kwargs["protoAug_weight"]*loss_protoAug + self.kwargs["kd_weight"]*loss_kd, single_preds + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + Perform operations after completing the training for a specific task. + 1. Save the prototypes of the current model. + 2. Save the current model state to a file. + 3. Load the saved model state as the old model for future reference. + + Parameters: + - task_idx (int): The index of the current task. + - buffer: Data buffer for storing samples (not used in this function). + - train_loader (DataLoader): DataLoader for the training dataset of the current task. + - test_loaders (list of DataLoader): List of DataLoaders for test datasets of different tasks. + + Example Usage: + self.after_task(task_idx, buffer, train_loader, test_loaders) + ''' + # Save the prototype + self.protoSave(self.model, train_loader, self.task_idx) + self.numclass += self.task_size + + self.old_model = copy.deepcopy(self.model) + self.old_model.eval() + + def protoSave(self, model, loader, current_task): + ''' + Save the class prototypes for the current task. + + This function extracts features from the data using the provided model and computes + class prototypes based on these features. The prototypes are then saved to the class + attributes. If it's the first task, the prototypes are initialized. For subsequent + tasks, the prototypes are updated with new class information. + + Parameters: + - model: The model used for feature extraction. + - loader: DataLoader providing the dataset for the current task. + - current_task (int): The index of the current task. + + Code Reference: + https://github.com/ShiWuxuan/PRAKA/blob/master/Cifar100/jointSSL.py + ''' + + features = [] + labels = [] + model.eval() + # Feature extraction + with torch.no_grad(): + for i, batch in enumerate(loader): + images, target = batch['image'], batch['label'] + feature = model.feature(images.to(self.device)) + if feature.shape[0] == loader.batch_size: + labels.append(target.numpy()) + features.append(feature.cpu().numpy()) + + labels_set = np.unique(labels) + labels = np.array(labels) + labels = np.reshape(labels, labels.shape[0] * labels.shape[1]) + features = np.array(features) + features = np.reshape(features, (features.shape[0] * features.shape[1], features.shape[2])) + + # Compute class prototypes + prototype = {} + class_label = [] + numsamples = {} + + for item in labels_set: + index = np.where(item == labels)[0] + class_label.append(item) + feature_classwise = features[index] + prototype[item] = np.mean(feature_classwise, axis=0) + # Record the number of samples for each class. + numsamples[item] = feature_classwise.shape[0] + if current_task == 0: + self.prototype = prototype + self.class_label = class_label + self.numsamples = numsamples + else: + self.prototype.update(prototype) + self.class_label = np.concatenate((class_label, self.class_label), axis=0) + self.numsamples.update(numsamples) + + def get_parameters(self, config): + return self.model.parameters() + diff --git a/core/model/ranpac.py b/core/model/ranpac.py new file mode 100644 index 0000000000000000000000000000000000000000..3b1b29a5ef5de20a42c0ea3ce355b7b0872f895d --- /dev/null +++ b/core/model/ranpac.py @@ -0,0 +1,269 @@ +''' +@misc{mcdonnell2024ranpacrandomprojectionspretrained, + title={RanPAC: Random Projections and Pre-trained Models for Continual Learning}, + author={Mark D. McDonnell and Dong Gong and Amin Parveneh and Ehsan Abbasnejad and Anton van den Hengel}, + year={2024}, + eprint={2307.02251}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2307.02251}, +} + +Code Reference: +https://github.com/RanPAC/RanPAC +''' + +import copy +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone.transformer import MultiHeadAttention_LoRA, VisionTransformer +from .backbone.clip import CLIP, tokenize +from .backbone.vit import ViTZoo, ViT_in21k_adapter + +VIT = ViT_in21k_adapter +CLIP = CLIP + +class CosineLinear(nn.Module): + def __init__(self, in_features, out_features): + + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features)) + self.sigma = nn.Parameter(torch.Tensor(1)) + self.reset_parameters() + + self.use_RP = False + self.W_rand = None + + def reset_parameters(self): + + stdv = 1. / math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + self.sigma.data.fill_(1) + + def forward(self, input): + + if not self.use_RP: + out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1)) + else: + if self.W_rand is not None: + inn = F.relu(input @ self.W_rand) + else: + assert 0, 'should not reach here, for now' + inn = input + out = F.linear(inn, self.weight) + + out = self.sigma * out + + return out + +class Network(nn.Module): + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self._cur_task_id = -1 + self.backbone = backbone + self.device = device + self.classifier = None + + if isinstance(self.backbone, VIT): + self.feature_dim = self.backbone.feat_dim + elif isinstance(self.backbone, CLIP): + # Assuming the final features_dim is concat of image and text + self.feature_dim = self.backbone.visual.output_dim + self.backbone.transformer.width + self.accm_class_names = [] + self.curr_class_names = [] + self.accm_text_tokens = None + self.curr_text_tokens = None + + self.prompt_template = kwargs['prompt_template'] + + def update_classifer(self, num_classes, train_loader): + + if isinstance(self.backbone, VIT): + pass + elif isinstance(self.backbone, CLIP): + self.curr_class_names = train_loader.dataset.get_class_names() + self.accm_class_names += self.curr_class_names + + self.curr_text_tokens = tokenize( + [self.prompt_template.format(c) for c in self.curr_class_names] + ).to(self.device) + + self.accm_text_tokens = tokenize( + [self.prompt_template.format(c) for c in self.accm_class_names] + ).to(self.device) + else: + assert 0 + + self._cur_task_id += 1 + del self.classifier + self.classifier = CosineLinear(self.feature_dim, num_classes).to(self.device) + + def get_feature(self, x): + + if isinstance(self.backbone, VIT): + return self.backbone(x) + elif isinstance(self.backbone, CLIP): + features_image, features_text, logits_per_image, logits_per_text = self.backbone(x, self.curr_text_tokens) + + max_indices = logits_per_image.softmax(dim=-1).argmax(dim=1) # Shape will be [48] + max_features = features_text[max_indices] # Shape will be [48, 768] + + return torch.cat([features_image, max_features], dim=1) # Shape will be [48, 1536] + else: + assert 0 + + def forward(self, x, inference=False): + + if isinstance(self.backbone, VIT): + features = self.backbone(x) + elif isinstance(self.backbone, CLIP): + if inference: + features_image, features_text, logits_per_image, logits_per_text = self.backbone(x, self.accm_text_tokens) + else: + features_image, features_text, logits_per_image, logits_per_text = self.backbone(x, self.curr_text_tokens) + + max_indices = logits_per_image.softmax(dim=-1).argmax(dim=1) # Shape will be [48] + max_features = features_text[max_indices] # Shape will be [48, 768] + features = torch.cat([features_image, max_features], dim=1) # Shape will be [48, 1536] + else: + assert 0 + + return self.classifier(features) + +class RanPAC(nn.Module): + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self._network = Network(backbone, device, **kwargs) + + self.device = device + self.first_session_training = kwargs["first_session_training"] + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.total_cls_num = kwargs['total_cls_num'] + self.task_num = kwargs["task_num"] + #self.use_RP = kwargs["use_RP"] + self.M = kwargs['M'] + + self._known_classes = 0 + self._classes_seen_so_far = 0 + self._skip_train = False # this flag is used to skip training + + self._network.to(self.device) + + if isinstance(backbone, CLIP): + for name, param in self._network.named_parameters(): + if 'adapt' not in name: + param.requires_grad = False + + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + if task_idx == 0: + self._classes_seen_so_far = self.init_cls_num + elif task_idx > 0: + self._classes_seen_so_far += self.inc_cls_num + + self._network.update_classifer(self._classes_seen_so_far, train_loader) + + if task_idx == 0 and self.first_session_training: + self._skip_train = False + else: + self._skip_train = True + print(f"Not training on task {task_idx}") + + def observe(self, data): + + if self._skip_train: + # set required_grad be True so that it can call backward() but don't do anything + return None, 0., torch.tensor(0., device = self.device, requires_grad = True) + + inputs, targets = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + + logits = self._network(inputs) + loss = F.cross_entropy(logits, targets) + + _, preds = torch.max(logits, dim=1) + correct = preds.eq(targets.expand_as(preds)).sum().item() + total = len(targets) + + acc = round(correct / total, 4) + + return preds, acc, loss + + def inference(self, data): + + inputs, targets = data['image'].to(self.device), data['label'] + logits = self._network(inputs, True) + _, preds = torch.max(logits, dim=1) + + correct = preds.cpu().eq(targets.expand_as(preds)).sum().item() + total = len(targets) + + acc = round(correct / total, 4) + + return logits, acc + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + self._known_classes = self._classes_seen_so_far + + if task_idx == 0: + + # Initialize attribute for random projection classifier + self.W_rand = torch.randn(self._network.classifier.in_features, self.M) + self.Q = torch.zeros(self.M, self.init_cls_num) # C + self.G = torch.zeros(self.M, self.M) + + else: + self.Q = torch.cat((self.Q, torch.zeros(self.M, self.inc_cls_num)), dim=1) + + self.update_rp_classifier(train_loader, test_loaders[0].dataset.trfms) + + @torch.no_grad() + def update_rp_classifier(self, train_loader, test_trfms): + + self._network.eval() + train_loader.dataset.trfms = test_trfms + + self._network.classifier.use_RP = True + self._network.classifier.W_rand = self.W_rand.to(self.device) # feature_dim x M + + feature_list, label_list = [], [] + for batch in train_loader: + x, y = batch['image'].to(self.device), batch['label'] + feature_list.append(self._network.get_feature(x).cpu()) + label_list.append(y) + feature_list, label_list = torch.cat(feature_list, dim = 0), torch.cat(label_list, dim = 0) + + label_list = F.one_hot(label_list, self._classes_seen_so_far).to(torch.float32) + + proj_feature_list = F.relu(feature_list @ self.W_rand) + + self.Q += proj_feature_list.T @ label_list + self.G += proj_feature_list.T @ proj_feature_list + + ridges = 10.0**np.arange(-8,9) + num_val_samples = int(proj_feature_list.shape[0] * 0.8) + losses = [] + Q_val = proj_feature_list[:num_val_samples, :].T @ label_list[:num_val_samples, :] + G_val = proj_feature_list[:num_val_samples, :].T @ proj_feature_list[:num_val_samples, :] + for ridge in ridges: + Wo = torch.linalg.solve(G_val + ridge * torch.eye(self.M), Q_val).T #better nmerical stability than .inv + Y_train_pred = proj_feature_list[num_val_samples:, :] @ Wo.T + losses.append(F.mse_loss(Y_train_pred, label_list[num_val_samples:, :])) + ridge = ridges[np.argmin(np.array(losses))] + print(f"Optimal lambda: {ridge}") + + Wo = torch.linalg.solve(self.G + ridge * torch.eye(self.M), self.Q).T #better nmerical stability than .inv + self._network.classifier.weight.data = Wo[:self._network.classifier.weight.shape[0], :].to(self.device) # num_classes x M + + def get_parameters(self, config): + return self._network.parameters() \ No newline at end of file diff --git a/core/model/rapf.py b/core/model/rapf.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd2e63479d62e4dbabda5ef9b7c3e1694db6fca --- /dev/null +++ b/core/model/rapf.py @@ -0,0 +1,377 @@ +import copy + +import torch +import torch.nn as nn +import numpy as np +import os +import random + +from tqdm import tqdm + +from .backbone.clip import tokenize +from core.data import dataloader +from core.model import backbone +from core.model.finetune import Finetune +from torch.utils.data import DataLoader + + +def get_class_ids_per_task(init_cls_num, inc_cls_num, class_order): + yield class_order[:init_cls_num] + for i in range(init_cls_num, len(class_order), inc_cls_num): + yield class_order[i:i + inc_cls_num] + +def get_class_names(classes_names, prev_cls_num, accu_cls_num): + return [classes_names[i] for i in range(prev_cls_num, accu_cls_num)] + +def shrink_cov(cov): + diag_mean = torch.mean(torch.diagonal(cov)) + off_diag = cov.clone() + off_diag.fill_diagonal_(0.0) + mask = off_diag != 0.0 + off_diag_mean = (off_diag*mask).sum() / mask.sum() + iden = torch.eye(cov.shape[0], device=cov.device) + alpha1 = 1 + alpha2 = 1 + cov_ = cov + (alpha1*diag_mean*iden) + (alpha2*off_diag_mean*(1-iden)) + return cov_ +def sample(mean, cov, size, shrink=False): + vec = torch.randn(size, mean.shape[-1], device=mean.device) + if shrink: + cov = shrink_cov(cov) + sqrt_cov = torch.linalg.cholesky(cov) + vec = vec @ sqrt_cov.t() + vec = vec + mean + return vec + +def seed_everything(seed=0): + """Fix all random seeds""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + os.environ['PYTHONHASHSEED'] = str(seed) + + +""" +This clas refer to the following repository: +https://github.com/linlany/RAPF +""" +class ClassIncrementalCLIP(nn.Module): + def __init__(self, model, **kwargs): + super().__init__() + device = kwargs['device'] + fp16 = kwargs['fp16'] + mix_bias = kwargs['mix_bias'] + self.prompt_template = kwargs['prompt_template'] + self.initial_increment = kwargs['init_cls_num'] + self.increment = kwargs['inc_cls_num'] + self.device = device + self.classes_names = None + # self.class_order = kwargs['class_order'] + self.visual = model.visual + self.transformer = model.transformer + self.positional_embedding = model.positional_embedding + self.token_embedding = model.token_embedding + self.ln_final = model.ln_final + self.text_projection = model.text_projection + self.logit_scale = model.logit_scale + # pdb.set_trace() + # self.class_ids_per_task = list(get_class_ids_per_task(self.initial_increment, self.increment, self.class_order)) + self.current_class_names = [] + self.text_tokens = None + self.dtype = torch.float16 if fp16 else torch.float32 + self.adapter = nn.Linear(512, 512, bias=False ,device=device) + self.clip_type = model.dtype + + + # old adapter + self.old_adapter = None + self.old_edge_samples = [] + self.old_edge_samples_labels = [] + self.old_edge_samples_nearest_labels = [] + + # class stat + self.class_mean_list = [] + self.class_cov_list = [] + + self.class_diff = None + self.nearest_class = None + self.class_edge_distance = [] + self.mix_b = mix_bias + + def encode_text(self, text, prompt=False): + x = self.token_embedding(text).type(self.clip_type) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding.type(self.clip_type) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def encode_image(self, image): + # 确保输入数据类型与 self.visual 的权重类型一致 + image = image.to(self.clip_type) + return self.visual(image) + + + @torch.no_grad() + def get_class_name_features(self): + class_name_features = self.encode_text(self.text_tokens) + return class_name_features.type(torch.float32) + + def forward(self, image, ori_ima_f=False, memory_data=None, not_ini=False, edge_sample=None, prompt=False): + image = image.type(torch.float16) + with torch.no_grad(): + text_features = self.encode_text(self.text_tokens) + + + with torch.no_grad(): + image_features = self.encode_image(image) + original_image_features = image_features.clone() + if memory_data is not None: + memory_data = memory_data.type(self.dtype) + image_features = torch.cat([image_features, memory_data], dim=0) + if edge_sample is not None: + edge_sample = edge_sample.type(self.dtype) + edge_num = edge_sample.shape[0] + image_features = torch.cat([image_features, edge_sample], dim=0) + + image_features = self.adapter(image_features.type(self.dtype).detach()).type(self.clip_type) + + image_features = image_features / image_features.norm(dim=1, keepdim=True) + if edge_sample is not None: + edge_sample_features = image_features[-edge_num:] + image_features = image_features[:-edge_num] + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t().type(image_features.dtype) + + probs = logits_per_image + if not_ini: + with torch.no_grad(): + old_memory_feature = self.old_adapter(memory_data) + old_memory_feature = old_memory_feature / old_memory_feature.norm(dim=1, keepdim=True) + if edge_sample is not None: + return probs, image_features, old_memory_feature, edge_sample_features + return probs, image_features, old_memory_feature, text_features + if ori_ima_f: + if memory_data is not None: + image_features = image_features[:-memory_data.shape[0]] + return probs, original_image_features, image_features + return probs, image_features, None, None + + def adaptation(self, task_id, prev_cls_num, accu_cls_num, threshold=0): + self.current_class_names += get_class_names(self.classes_names, prev_cls_num, accu_cls_num) + self.text_tokens = tokenize( + [self.prompt_template.format(c) for c in self.current_class_names] + ).to(self.device) + self.text_end = self.text_tokens.max(dim=-1)[1] + self.class_name_features = self.get_class_name_features() + self.class_name_features = self.class_name_features / self.class_name_features.norm(dim=-1, p=2, keepdim=True) + self.queue_empty = True + self.hard_pairs = None + if task_id>0: + self.old_adapter = copy.deepcopy(self.adapter) + dist_list = [] + for k, class_name_feature in enumerate(self.class_name_features[:prev_cls_num]): + diff = torch.cdist(self.class_name_features[prev_cls_num:].type(torch.float32), class_name_feature.unsqueeze(0).type(torch.float32)).squeeze() + dist_list.append(diff) + dist_list = torch.stack(dist_list) + self.class_diff = dist_list + mask = self.class_diff < threshold + indices = torch.nonzero(mask) + self.hard_new_class = torch.unique(indices[:,1]) + self.initial_increment+(task_id-1) * self.increment + num_hard_class = self.hard_new_class.shape[0] + self.hard_pairs = indices + self.hard_pairs[:,1] = self.hard_pairs[:,1]+self.initial_increment+(task_id-1) * self.increment + def get_old_edge_samples(self, batch_size): + random_select = torch.randperm(self.old_edge_samples.shape[0])[:batch_size] + return self.old_edge_samples[random_select], self.old_edge_samples_labels[random_select], self.old_edge_samples_nearest_labels[random_select] + + + def analyze_mean_cov(self, features, labels): + label = torch.sort(torch.unique(labels))[0] + for l in label: + index = torch.nonzero(labels == l) + index = index.squeeze() + class_data = features[index] + mean = class_data.mean(dim=0) + cov = torch.cov(class_data.t()) + 1e-4* torch.eye(class_data.shape[-1], device=class_data.device) + distance = torch.cdist(class_data, mean.unsqueeze(0)).squeeze() + max_distance = torch.sort(distance)[0][-10:] + self.class_edge_distance.append((max_distance.mean()-max_distance.min(), max_distance.max() - max_distance.mean(), max_distance.mean())) + self.class_mean_list.append(mean) + self.class_cov_list.append(cov) + + def mix_matrix(self): + if self.old_adapter is not None: + weight_new = self.adapter.weight.data + weight_old = self.old_adapter.weight.data + dist = (weight_new - weight_old).abs() + U_old, S_old, V_old = torch.linalg.svd(weight_old) + P_new = U_old.T @ weight_new + dist = (P_new - torch.diag(S_old)@V_old).abs() + mask = dist / dist.max() + mask += self.mix_b + mask = torch.clamp(mask, max=1) + right = P_new * mask + torch.diag(S_old)@V_old * (1-mask) + weight = U_old @ right + self.adapter.weight.data = weight + +""" +This clas refer to the following repository: +https://github.com/linlany/RAPF +""" +class RAPF(nn.Module): + def __init__(self, backbone, **kwargs): + super().__init__() + seed = kwargs['seed'] + seed_everything(seed) + self.backbone = backbone + self.kwargs = kwargs + self.model = ClassIncrementalCLIP(self.backbone, **kwargs) + self.device = kwargs['device'] + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + self.beta = kwargs['beta'] + self.shrinkage = kwargs['shrinkage'] + self.threshold = kwargs['threshold'] + self.train_batch_size = kwargs['train_batch_size'] + self.batch_size = kwargs['batch_size'] + self.num_workers = kwargs['num_workers'] + self.seed = seed + + self.prev_cls_num = 0 + self.accu_cls_num = 0 + + + + def before_task(self, task_id, buffer, train_loader, test_loaders): + self.task_id = task_id + if self.task_id == 0: + self.accu_cls_num = self.init_cls_num + else: + self.accu_cls_num += self.inc_cls_num + + self.model.adaptation(task_id, self.prev_cls_num, self.accu_cls_num, self.threshold) + if self.task_id > 0: + random_class_order_list = list(range(self.init_cls_num+(self.task_id-1)*self.inc_cls_num)) + random.shuffle(random_class_order_list) + self.random_class_order_list = random_class_order_list + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + sample_data = [] + sample_target = [] + sample_after_adapt_feature = [] + model = self.model + for batch in tqdm(train_loader, total=len(train_loader)): + feats = batch['image'] + target = batch['label'] + feats, target = feats.to(self.device), target.to(self.device) + with torch.no_grad(): + _, ori_ima_feat, after_adapt_feature = model(feats, ori_ima_f=True) + sample_data.append(ori_ima_feat) + sample_target.append(target) + sample_after_adapt_feature.append(after_adapt_feature) + sample_target = torch.cat(sample_target, dim=0) + sample_data = torch.cat(sample_data, dim=0) + sample_after_adapt_feature = torch.cat(sample_after_adapt_feature, dim=0) + model.analyze_mean_cov(sample_data, sample_target) + model.mix_matrix() + self.prev_cls_num = self.accu_cls_num + + def get_parameters(self, config): + return self.model.adapter.parameters() + + def observe(self, data): + loss = torch.tensor(0.0).to(self.device) + loss_c = torch.tensor(0.0).to(self.device) + loss_hinge = torch.tensor(0.0).to(self.device) + + inputs = data['image'] + targets = data['label'] + inputs, targets = inputs.to(self.device), targets.to(self.device) + sg_inputs = None + edge_sample = None + ori_targets = targets.clone() + model = self.model + if self.task_id > 0: + sg_inputs = [] + sg_targets = [] + # num of classes per batch. Ensure an epoch traverses all classes at least once. + # For exemple, if there are 100 classes and 50 batches per epoch , there will be 2 classes per batch. + + random_class_order_list = self.random_class_order_list + batch_id = data['batch_id'] + if self.inc_cls_num == 5: + list_for_one_batch = [random_class_order_list[batch_id*4%len(random_class_order_list)], random_class_order_list[(batch_id*4+1)%len(random_class_order_list)], random_class_order_list[(batch_id*4+2)%len(random_class_order_list)], random_class_order_list[(batch_id*4+3)%len(random_class_order_list)]] + else: + list_for_one_batch = [random_class_order_list[batch_id*2%len(random_class_order_list)], random_class_order_list[(batch_id*2+1)%len(random_class_order_list)]] + + + for i in list_for_one_batch: + sg_inputs.append(sample(model.class_mean_list[i], model.class_cov_list[i],int(10*self.beta), shrink=self.shrinkage)) + sg_targets.append(torch.ones(int(10*self.beta), dtype=torch.long, device=self.device)*i) + sg_inputs = torch.cat(sg_inputs, dim=0) + sg_targets = torch.cat(sg_targets, dim=0) + targets = torch.cat([targets, sg_targets], dim=0) + if model.hard_pairs is not None and model.hard_pairs.shape[0] > 0: + edge_sample = [] + edge_p_target = [] + edge_n_target = [] + for hard_pair in model.hard_pairs: + edge_sample.append(sample(model.class_mean_list[hard_pair[0]], model.class_cov_list[hard_pair[0]],int(20*self.beta), shrink=self.shrinkage)) + edge_p_target.append(torch.ones(int(20*self.beta), dtype=torch.long, device=self.device)*hard_pair[0]) + edge_n_target.append(torch.ones(int(20*self.beta), dtype=torch.long, device=self.device)*hard_pair[1]) + edge_sample = torch.cat(edge_sample, dim=0) + edge_p_target = torch.cat(edge_p_target, dim=0) + edge_n_target = torch.cat(edge_n_target, dim=0) + if self.task_id > 0: + not_ini = True + else: + not_ini = False + outputs, _, __, edge_sample_features = model(inputs, memory_data=sg_inputs, not_ini=not_ini, edge_sample=edge_sample, prompt=False) + + if self.task_id > 0: + if edge_sample is not None: + edge_sample_features = edge_sample_features / edge_sample_features.norm(dim=-1, keepdim=True) + edge_target_features = model.class_name_features[edge_p_target].type(edge_sample_features.dtype) + edge_target_features = edge_target_features / edge_target_features.norm(dim=-1, keepdim=True) + edge_nearest_class_features = model.class_name_features[edge_n_target].type(edge_sample_features.dtype) + edge_nearest_class_features = edge_nearest_class_features / edge_nearest_class_features.norm(dim=-1, keepdim=True) + loss_hinge = torch.relu(- (edge_sample_features * edge_target_features.clone().detach()).sum(-1) + (edge_sample_features * edge_nearest_class_features.clone().detach()).sum(-1) + 0.1).mean() + loss_c = torch.nn.functional.cross_entropy(outputs, targets.detach()) + if edge_sample is not None: + loss = loss_c + loss_hinge + else: + loss = loss_c + # Return tuple [pred, acc, loss] + # with torch.no_grad(): + # prob_outputs = torch.nn.functional.softmax(outputs, dim=-1) + predicted_labels = outputs.argmax(dim=1) + predicted_labels = predicted_labels[:ori_targets.size(0)] + corrects = (predicted_labels == ori_targets).sum().item() + total_predictions = ori_targets.size(0) + accuracy = corrects / total_predictions + return predicted_labels, accuracy, loss + + + def inference(self, data): + feats = data['image'] + target = data['label'] + feats, target = feats.to(self.device), target.to(self.device) + model = self.model + with torch.no_grad(): + outputs, _, __, ___ = model(feats, prompt=False) + prob_outputs = torch.nn.functional.softmax(outputs, dim=-1) + predicted_labels = prob_outputs.argmax(dim=1) + corrects = (predicted_labels == target).sum().item() + total_predictions = target.size(0) + accurcy = corrects / total_predictions + return prob_outputs, accurcy \ No newline at end of file diff --git a/core/model/sd_lora.py b/core/model/sd_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..9115eb0121e6353cb0bc45d185f62a94e8e7b16d --- /dev/null +++ b/core/model/sd_lora.py @@ -0,0 +1,210 @@ +""" +@misc{wu2025sdlorascalabledecoupledlowrank, + title={SD-LoRA: Scalable Decoupled Low-Rank Adaptation for Class Incremental Learning}, + author={Yichen Wu and Hongming Piao and Long-Kai Huang and Renzhen Wang and Wanhua Li and Hanspeter Pfister and Deyu Meng and Kede Ma and Ying Wei}, + year={2025}, + eprint={2501.13198}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2501.13198}, +} + +Adapted from https://github.com/WuYichen-97/SD-Lora-CL + +""" + +import torch +import torch.nn as nn +import copy +import numpy as np + +from torch.nn import functional as F +from .backbone.transformer import MultiHeadAttention_SDLoRA + +class Model(nn.Module): + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self._cur_task_id = -1 + self.backbone = backbone + self.device = device + self.embed_dim = kwargs["embd_dim"] + self.init_cls_num = kwargs['init_cls_num'] + self.inc_cls_num = kwargs['inc_cls_num'] + + def update_fc(self): + + self._cur_task_id += 1 + if self._cur_task_id == 0: + classifier = nn.Linear(self.embed_dim, self.init_cls_num, bias=True) + + nn.init.kaiming_uniform_(classifier.weight, nonlinearity='linear') + nn.init.constant_(classifier.bias, 0) + else: + classifier = nn.Linear(self.embed_dim, self.init_cls_num + self.inc_cls_num * (self._cur_task_id), bias=True) + + nn.init.kaiming_uniform_(classifier.weight, nonlinearity='linear') + nn.init.constant_(classifier.bias, 0) + + nb_output = self.classifier.out_features + classifier.weight.data[:nb_output] = copy.deepcopy(self.classifier.weight.data) + classifier.bias.data[:nb_output] = copy.deepcopy(self.classifier.bias.data) + del self.classifier + + self.classifier = classifier + + def forward(self, x, inference = False): + + features = self.backbone(x) + logits = self.classifier(features) + return logits + +class SD_LoRA(nn.Module): + + def __init__(self, backbone, device, **kwargs): + + super().__init__() + + self.device = device + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.task_num = kwargs["task_num"] + self.init_mag = kwargs['init_mag'] + self.rank_reduction = kwargs['rank_reduction'] + self.knowledge_dist = kwargs['knowledge_dist'] + self._known_classes = 0 + + self._network = Model(backbone, device, **kwargs) + self.attention_modules = [module for module in self._network.modules() if isinstance(module, MultiHeadAttention_SDLoRA)] + + def observe(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + logits = self._network(x) + + # Masked previous classes + fake_y = y - self._known_classes + loss = F.cross_entropy(logits[:, self._known_classes:], fake_y) + + preds = logits.max(1)[1] + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc, loss + + def inference(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + logits = self._network(x, inference = True) + preds = logits.max(1)[1] + + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc + + @torch.no_grad() + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + self._network.update_fc() + + if self.rank_reduction[0]: + if task_idx == self.rank_reduction[1]: + for module in self.attention_modules: + module.lora_rank = self.rank_reduction[3] + + elif task_idx == self.rank_reduction[2]: + for module in self.attention_modules: + module.lora_rank = self.rank_reduction[4] + + # All blocks share same magnitude + mag = nn.ParameterList([nn.Parameter(torch.Tensor([self.init_mag])) for _ in range(task_idx + 1)]) + for module in self.attention_modules: + module.mag_lora = mag + module.init_param() + + self._network = self._network.to(self.device) + + unfrezeed_params = [] + for name, param in self._network.named_parameters(): + param.requires_grad_(False) + if f'classifier' in name or \ + f'lora' and f'list.{task_idx}' in name or \ + ('mag' in name and 'assimilated' not in name): + param.requires_grad_(True) + unfrezeed_params.append(name) + + print(f"Current task : {task_idx}, Parameters to be updated: {len(unfrezeed_params)}") + + @torch.no_grad() + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + + if self.knowledge_dist[0] and task_idx > 0: + for layer, module in enumerate(self.attention_modules): + + dirs_q, dirs_v = [], [] + for i in range(len(module.lora_A_q_list)): + + norm_B = torch.norm(module.lora_B_q_list[i].weight) + norm_A = torch.norm(module.lora_A_q_list[i].weight) + + if norm_A != 0 and norm_B != 0: + dirs_q.append( + (module.lora_B_q_list[i].weight @ module.lora_A_q_list[i].weight) / (norm_B * norm_A) + ) + else: # zero-tensor, for consistency + dirs_q.append( + (module.lora_B_q_list[i].weight @ module.lora_A_q_list[i].weight) + ) + + norm_B = torch.norm(module.lora_B_v_list[i].weight) + norm_A = torch.norm(module.lora_A_v_list[i].weight) + + if norm_A != 0 and norm_B != 0: + dirs_v.append( + (module.lora_B_v_list[i].weight @ module.lora_A_v_list[i].weight) / (norm_B * norm_A) + ) + else: # zero-tensor, for consistency + dirs_v.append( + (module.lora_B_q_list[i].weight @ module.lora_A_q_list[i].weight) + ) + + flatten_dirs = [dir_q.flatten() for dir_q in dirs_q] + + last_dir = flatten_dirs[-1].unsqueeze(1) + prev_dirs = torch.stack(flatten_dirs[:-1], dim=-1) + + alphas = torch.linalg.lstsq(prev_dirs, last_dir) + + if alphas.residuals < self.knowledge_dist[1]: + print(f'Layer {layer}: {alphas.residuals.item()} < {self.knowledge_dist[1]}, Q Merged with {alphas.solution}') + + assert prev_dirs.shape[1] == len(module.assimilated_mag_lora_q) - 1 + for ii in range(prev_dirs.shape[1]): + module.assimilated_mag_lora_q[ii] += alphas.solution[i] + + nn.init.zeros_(module.lora_B_q_list[task_idx]) + nn.init.zeros_(module.lora_A_q_list[task_idx]) + + flatten_dirs = [dir_v.flatten() for dir_v in dirs_v] + + last_dir = flatten_dirs[-1].unsqueeze(1) + prev_dirs = torch.stack(flatten_dirs[:-1], dim=-1) + + alphas = torch.linalg.lstsq(prev_dirs, last_dir) + + if alphas.residuals < self.knowledge_dist[1]: + print(f'Layer {layer}: {alphas.residuals.item()} < {self.knowledge_dist[1]}, V Merged with {alphas.solution}') + + assert prev_dirs.shape[1] == len(module.assimilated_mag_lora_v) - 1 + for ii in range(prev_dirs.shape[1]): + module.assimilated_mag_lora_v[ii] += alphas.solution[i] + + nn.init.zeros_(module.lora_B_v_list[task_idx]) + nn.init.zeros_(module.lora_A_v_list[task_idx]) + + def get_parameters(self, config): + return self._network.parameters() \ No newline at end of file diff --git a/core/model/trgp.py b/core/model/trgp.py new file mode 100644 index 0000000000000000000000000000000000000000..078729f292a4bf3b694572988f125a4452ee2313 --- /dev/null +++ b/core/model/trgp.py @@ -0,0 +1,427 @@ +""" +@article{lin2022trgp, + title={TRGP: Trust Region Gradient Projection for Continual Learning}, + author={Lin, Sen and Yang, Li and Fan, Deliang and Zhang, Junshan}, + journal={arXiv preprint arXiv:2202.02931}, + year={2022} +} + +Code Reference: +https://github.com/LYang-666/TRGP +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from .backbone.alexnet import Conv2d_TRGP, Linear_TRGP, AlexNet_TRGP +from .backbone.clip import tokenize, CLIP + +Epsilon = 0.5 + +AlexNet = AlexNet_TRGP +Clip = CLIP + +class TopK: + + ''' + A class to maintain a collection of the top K items based on a specified attribute. + + This class allows for the dynamic addition of items, each represented as a dictionary, + where each dictionary must have a key 'proj_norm' that represents the value used + to determine the ranking. The class keeps track of the top K items with the highest + 'proj_norm' values. + ''' + + def __init__(self, k): + self.k = k + self.top_k_list = [] + + def add(self, dict): + if len(self.top_k_list) < self.k: + self.top_k_list.append(dict) + elif dict['proj_norm'] > min(self.top_k_list, key=lambda x: x['proj_norm'])['proj_norm']: + self.top_k_list.remove(min(self.top_k_list, key=lambda x: x['proj_norm'])) + self.top_k_list.append(dict) + + def get_top_k(self): + return self.top_k_list + +class Network(nn.Module): + def __init__(self, backbone, **kwargs): + super().__init__() + self.backbone = backbone + + self.classifiers = nn.ModuleList([ + nn.Linear(backbone.feat_dim, kwargs['init_cls_num'], bias = False)] + + [nn.Linear(backbone.feat_dim, kwargs['inc_cls_num'], bias = False) for _ in range(kwargs['task_num'] - 1)] + ) + + def return_hidden(self, data): + return self.backbone(data) + + def forward(self, data, compute_input_matrix = False): + logits = [] + image_features = self.backbone(data, compute_input_matrix) + for classifier in self.classifiers: + logits.append(classifier(image_features)) + + return logits + +class TRGP(nn.Module): + + def __init__(self, backbone, device, **kwargs): + super().__init__() + + self.backbone = backbone + self.device = device + self.task_num = kwargs["task_num"] + self.init_cls_num = kwargs["init_cls_num"] + self.inc_cls_num = kwargs["inc_cls_num"] + self.label_smoothing = kwargs['label_smoothing'] + + self._known_classes = 0 + self.feature_list = [] + self.feature_mat = [] + self.layers = [] + + if isinstance(backbone, Clip): + self.network = backbone + + self.visual_U = [] + self.lamda = [[0 for _ in range(12)] for _ in range(12)] + + self.accm_class_names = [] + self.curr_class_names = [] + self.accm_text_tokens = None + self.curr_text_tokens = None + + self.prompt_template = kwargs['prompt_template'] + + # 12 Visual Transformer's Adapter * 2 (Down & Up) + for name, module in self.network.named_modules(): + if 'visual' in name and isinstance(module, Linear_TRGP): + self.layers.append(module) + + for name, param in self.network.named_parameters(): + param.requires_grad = False + if 'adaptmlp' in name: + param.requires_grad = True + + elif isinstance(backbone, AlexNet): + self.network = Network(backbone, **kwargs) + + # # 3 Conv2d and 2 Linear + for module in self.network.modules(): # + if isinstance(module, Conv2d_TRGP) or isinstance(module, Linear_TRGP): + self.layers.append(module) + + else: + raise NotImplementedError + + self.feature_list_each_tasks = [[0 for _ in range(len(self.layers))] for _ in range(self.task_num)] + self.scale_param_each_tasks_each_layers = [[0 for _ in range(len(self.layers))] for _ in range(self.task_num)] + self.all_space = [[0 for _ in range(len(self.layers))] for _ in range(self.task_num)] + + self.network.to(self.device) + + def observe(self, data): + + x, y = data['image'].to(self.device), data['label'].to(self.device) - self._known_classes + if len(y) == 1: # Ignore batch_size == 1 + return None, 0., torch.zeros(1, requires_grad=True) + + if isinstance(self.backbone, Clip): + + features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.curr_text_tokens) + loss = F.cross_entropy(logits_per_img, y, label_smoothing=self.label_smoothing) + + preds = logits_per_img.softmax(dim=-1).argmax(dim=1) + + loss.backward() + + if self.cur_task > 0: + for i, module in enumerate(self.layers): + sz = module.weight.grad.data.shape[0] + module.weight.grad.data = module.weight.grad.data - (module.weight.grad.data.view(sz,-1) @ self.feature_mat[i]).view(module.weight.shape) + + elif isinstance(self.backbone, AlexNet): + + logits = self.network(x) + loss = F.cross_entropy(logits[self.cur_task], y, label_smoothing=self.label_smoothing) + + preds = logits[self.cur_task].max(1)[1] + + loss.backward() + + if self.cur_task > 0: + for i, module in enumerate(self.layers): + sz = module.weight.grad.data.shape[0] + module.weight.grad.data = module.weight.grad.data - (module.weight.grad.data.view(sz,-1) @ self.feature_mat[i]).view(module.weight.shape) + + else: + raise NotImplementedError + + acc = preds.eq(y).sum().item() / y.size(0) + + return preds, acc, loss + + def inference(self, data, task_id = -1): + + x, y = data['image'].to(self.device), data['label'].to(self.device) + + # Add dummy, to prevert batch_size == 1 + dummy_x = torch.randn_like(x[:1]) # only one dummy sample + x = torch.cat([x, dummy_x], dim=0) + + # Task-Aware (Task-Incremetanl Scenario) + if task_id > -1: + + if task_id == 0: + bias_classes = 0 + elif task_id == 1: + bias_classes = self.init_cls_num + else: + bias_classes = self.init_cls_num + (task_id - 1) * self.inc_cls_num + + if isinstance(self.backbone, Clip): + + for i, module in enumerate(self.layers): + module.space = self.all_space[task_id][i] + module.scale_param = nn.ParameterList([nn.Parameter(scale_param) for scale_param in self.scale_param_each_tasks_each_layers[task_id][i]]) + + features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.accm_text_tokens[bias_classes : self.init_cls_num + task_id * self.inc_cls_num]) + preds = logits_per_img.softmax(dim=-1).argmax(dim=1) + bias_classes + + elif isinstance(self.backbone, AlexNet): + + for i, module in enumerate(self.layers): + module.space = self.all_space[task_id][i] + module.scale_param = nn.ParameterList([nn.Parameter(scale_param) for scale_param in self.scale_param_each_tasks_each_layers[task_id][i]]) + + logits = self.network(x) + preds = logits[task_id].softmax(dim=-1).argmax(dim=1) + bias_classes + + else: + raise NotImplementedError + + # Task-Agnostic (Class-Incremetanl Scenario) + else: + + logits = [] + + if isinstance(self.backbone, Clip): + + for t in range(self.cur_task + 1): + for i, module in enumerate(self.layers): + + module.space = self.all_space[t][i] + module.scale_param = nn.ParameterList([nn.Parameter(scale_param) for scale_param in self.scale_param_each_tasks_each_layers[t][i]]) + + if t == 0: + features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.accm_text_tokens[: self.init_cls_num]) + else: + features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.accm_text_tokens[self.init_cls_num + (t-1) * self.inc_cls_num : self.init_cls_num + t * self.inc_cls_num]) + + logits.append(logits_per_img) + + elif isinstance(self.backbone, AlexNet): + + for t in range(self.cur_task + 1): + for i, module in enumerate(self.layers): + module.space = self.all_space[t][i] + module.scale_param = nn.ParameterList([nn.Parameter(scale_param) for scale_param in self.scale_param_each_tasks_each_layers[t][i]]) + + logits.append(self.network(x)[t]) + + else: + raise NotImplementedError + + preds = torch.cat(logits, dim=-1).softmax(dim=-1).argmax(dim=1) + + # Remove dummy + preds = preds[:-1] + + correct_count = preds.eq(y).sum().item() + acc = correct_count / y.size(0) + + return preds, acc + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + + # Last task have scale_param and space, need to init again + for module in self.layers: + module.disable_scale() + + self.cur_task = task_idx + + if isinstance(self.backbone, Clip): + + self.curr_class_names = train_loader.dataset.get_class_names() + self.accm_class_names += self.curr_class_names + + self.curr_text_tokens = tokenize([self.prompt_template.format(c) for c in self.curr_class_names]).to(self.device) + self.accm_text_tokens = tokenize([self.prompt_template.format(c) for c in self.accm_class_names]).to(self.device) + + if task_idx > 0: + + self.feature_mat = [torch.tensor(feat @ feat.T, dtype=torch.float32, device=self.device) for feat in self.feature_list] + optimizer = torch.optim.SGD(self.network.parameters(), lr = 0.01) # lr hardcoded + + x, y = [], [] + for batch in train_loader: + x.append(batch['image'].to(self.device)) + y.append(batch['label'].to(self.device) - self._known_classes) + + x, y = torch.cat(x, dim = 0), torch.cat(y, dim = 0) + + indices = torch.randperm(x.size(0)) + selected_indices = indices[:125] + x, y = x[selected_indices], y[selected_indices] + optimizer.zero_grad() + + if isinstance(self.backbone, Clip): + + features_img, features_txt, logits_per_img, logits_per_txt = self.network(x, self.curr_text_tokens) + loss = F.cross_entropy(logits_per_img, y) + + elif isinstance(self.backbone, AlexNet): + + logits = self.network(x) + loss = F.cross_entropy(logits[self.cur_task], y) + + loss.backward() + + for i, module in enumerate(self.layers): + + topk = TopK(2) + + grad = module.weight.grad.data.detach().cpu().numpy() + if isinstance(self.backbone, AlexNet) and isinstance(module, Conv2d_TRGP): + grad = grad.reshape(grad.shape[0], -1) + + for task_id in range(task_idx): + + proj = grad @ self.feature_list_each_tasks[task_id][i] @ self.feature_list_each_tasks[task_id][i].T + proj_norm = np.linalg.norm(proj) + + print(f'Layer {i} of {task_idx} to {task_id} : {proj_norm:.4f}/{np.linalg.norm(grad):.4f} ({proj_norm > Epsilon * np.linalg.norm(grad)})') + + if proj_norm > Epsilon * np.linalg.norm(grad): + topk.add({'proj_norm':proj_norm, 'task_id': task_id}) + + final_decision = [dic['task_id'] for dic in topk.get_top_k()] + module.enable_scale([ + torch.tensor(self.feature_list_each_tasks[task_id][i], dtype=torch.float32).to(self.device) for task_id in final_decision + ]) + print(f'Layer {i} of {task_idx} consider {final_decision} as trust region') + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + + self._known_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + + # Save the scale param + for i, module in enumerate(self.layers): + self.scale_param_each_tasks_each_layers[task_idx][i] = [scale_param.data for scale_param in module.scale_param] # top2 + self.all_space[task_idx][i] = module.space + module.disable_scale() + + x = torch.cat([batch['image'].to(self.device) for batch in train_loader], dim = 0) + + # hardcoded, choose 125 input from it + indices = torch.randperm(x.size(0)) + selected_indices = indices[:125] + x = x[selected_indices] + + self.network.eval() + + mat_list = [] # Representation / Activation of each layer + threshold = 0.97 + task_idx * 0.003 + + if isinstance(self.backbone, Clip): + + self.network(x, self.curr_text_tokens, compute_input_matrix = True) + + for module in self.layers: + + assert module.input_matrix.shape[0] == 125 + mat_list.append(module.input_matrix.view(-1, module.input_matrix.shape[-1]).detach().cpu().numpy().T) + + elif isinstance(self.backbone, AlexNet): + + self.network(x, compute_input_matrix = True) + + batch_list = [2*12,100,100] + ksize = [4, 3, 2] # kernel size of each conv layer + conv_output_size = [29, 12, 5] # output size of each conv layer + in_channel = [3, 64, 128] # input channel of each conv layer + + for i, module in enumerate(self.layers): + + if isinstance(module, Conv2d_TRGP): + bsz, ksz, s, inc = batch_list[i], ksize[i], conv_output_size[i], in_channel[i] + + mat = np.zeros((ksz * ksz * inc, s * s * bsz)) + act = module.input_matrix.detach().cpu().numpy() + + k = 0 + for kk in range(bsz): + for ii in range(s): + for jj in range(s): + mat[:,k]=act[kk, :, ii:ksz+ii, jj:ksz+jj].reshape(-1) + k += 1 + + mat_list.append(mat) + elif isinstance(module, Linear_TRGP): + mat_list.append(module.input_matrix.detach().cpu().numpy().T) + + # get the space for each layer + if task_idx == 0: + + for i, activation in enumerate(mat_list): + + U, S, _ = np.linalg.svd(activation, full_matrices = False) + # criteria (Eq-5) + sval_total = (S**2).sum() + sval_ratio = (S**2)/sval_total + r = np.sum(np.cumsum(sval_ratio) < threshold) + + self.feature_list_each_tasks[task_idx][i] = U[:, :r] + self.feature_list.append(U[:, :r]) + else: + + for i, activation in enumerate(mat_list): + + _, S, _ = np.linalg.svd(activation, full_matrices = False) + sval_total = (S**2).sum() + + delta = (self.feature_list[i].T @ activation @ activation.T @ self.feature_list[i]).diagonal() + + # following the GPM to get the sigma (S**2) + act_hat = activation - self.feature_list[i] @ self.feature_list[i].T @ activation + U, S, _ = np.linalg.svd(act_hat, full_matrices=False) + sigma = S**2 + + # stack delta and sigma, then sort in descending order + stack = np.hstack((delta, sigma)) + stack_index = np.argsort(stack)[::-1] # the index of each element in descending sorted array + stack = np.sort(stack)[::-1] # descending sorted array + + if threshold * sval_total <= 0: + r = 0 + else: + r = min(np.sum(np.cumsum(stack) < threshold * sval_total) + 1, activation.shape[0]) + + Ui = np.hstack((self.feature_list[i], U)) + sel_each = stack_index[:r] + sel_overall = sel_each[sel_each >= len(delta)] # without overlap + + self.feature_list[i] = np.hstack((self.feature_list[i], Ui[:, sel_overall])) + self.feature_list_each_tasks[task_idx][i] = Ui[:, sel_each] + + if sel_overall.shape[0] == 0: + print(f'Skip Updating Space for layer: {i+1}') + + def get_parameters(self, config): + return self.network.parameters() \ No newline at end of file diff --git a/core/model/wa.py b/core/model/wa.py new file mode 100644 index 0000000000000000000000000000000000000000..b943c599cb5a000ab05e70d6b039fe52457f96c3 --- /dev/null +++ b/core/model/wa.py @@ -0,0 +1,243 @@ +# -*- coding: utf-8 -*- +""" +@inproceedings{zhao2020maintaining, + title={Maintaining discrimination and fairness in class incremental learning}, + author={Zhao, Bowen and Xiao, Xi and Gan, Guojun and Zhang, Bin and Xia, Shu-Tao}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)}, + pages={13208--13217}, + year={2020} +} +https://arxiv.org/abs/1911.07053 + +Adapted from https://github.com/G-U-N/PyCIL/blob/master/models/wa.py, https://github.com/G-U-N/PyCIL/blob/master/utils/inc_net.py. +""" + +import torch +from torch import nn +import copy +from torch.nn import functional as F +import numpy as np +from .finetune import Finetune + + +def KD_loss(pred, soft, T=2): + ''' + Code Reference: + https://github.com/G-U-N/PyCIL/blob/master/models/wa.py + + Compute the knowledge distillation loss. + + Args: + pred (torch.Tensor): Predictions of the model. + soft (torch.Tensor): Soft targets. + T (float): Temperature parameter for softening the predictions. Default is 2. + + Returns: + torch.Tensor: Knowledge distillation loss. + ''' + pred = torch.log_softmax(pred / T, dim=1) + soft = torch.softmax(soft / T, dim=1) + return -1 * torch.mul(soft, pred).sum() / pred.shape[0] + + +class IncrementalModel(nn.Module): + ''' + Code Reference: + https://github.com/G-U-N/PyCIL/blob/master/utils/inc_net.py + + A model consists with a backbone and a classifier. + + Args: + backbone (nn.Module): Backbone network. + feat_dim (int): Dimension of the extracted features. + num_class (int): Number of classes in the dataset. + ''' + def __init__(self, backbone, feat_dim, num_class): + super().__init__() + self.backbone = backbone + self.feat_dim = feat_dim + self.num_class = num_class + self.classifier = None + + def forward(self, x): + return self.get_logits(x) + + def get_logits(self, x): + ''' + Compute logits for the input data. + + Args: + x (torch.Tensor): Input data. + + Returns: + torch.Tensor: Logits of the input data. + ''' + logits = self.classifier(self.backbone(x)['features']) + return logits + + def update_classifier(self, number_classes): + ''' + Incrementally update the classifier with deepcopy. + + Args: + number_classes (int): Number of classes after update. + ''' + classifier = nn.Linear(self.feat_dim, number_classes) + if self.classifier is not None: + number_output = self.classifier.out_features + weight = copy.deepcopy(self.classifier.weight.data) + bias = copy.deepcopy(self.classifier.bias.data) + classifier.weight.data[:number_output] = weight + classifier.bias.data[:number_output] = bias + + del self.classifier + self.classifier = classifier + + def classifier_weight_align(self, incremental_number): + ''' + Align the weight of the classifier after every task. + + Args: + incremental_number (int): Number of classes added in the current task. + ''' + weights = self.classifier.weight.data + new_norm = torch.norm(weights[-incremental_number:, :], p=2, dim=1) + old_norm = torch.norm(weights[:-incremental_number, :], p=2, dim=1) + new_mean = torch.mean(new_norm) + old_mean = torch.mean(old_norm) + gamma = old_mean / new_mean + self.classifier.weight.data[-incremental_number:, :] *= gamma + + def forward(self, x): + return self.get_logits(x) + + def get_logits(self, x): + logits = self.classifier(self.backbone(x)['features']) + return logits + + def freeze(self): + ''' + Freeze the model parameters. + ''' + for param in self.parameters(): + param.requires_grad = False + self.eval() + + return self + + def extract_vector(self, x): + ''' + Extract features from the backbone network. + + Args: + x (torch.Tensor): Input data. + + Returns: + torch.Tensor: Extracted features. + ''' + return self.backbone(x)["features"] + + +class WA(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.network = IncrementalModel(self.backbone, feat_dim, kwargs['init_cls_num']) + self.device = kwargs['device'] + self.old_network = None + self.known_classes = 0 + self.total_classes = 0 + self.task_idx = 0 + # For buffer update + self.total_classes_indexes = 0 + + def observe(self, data): + ''' + Do every current task. + + Args: + data (dict): Dictionary containing input data and labels. + + Returns: + tuple: Tuple containing predictions, accuracy, and loss. + ''' + x, y = data['image'].to(self.device), data['label'].to(self.device) + + self.network.to(self.device) + if self.old_network: + self.old_network.to(self.device) + + logits = self.network(x) + loss = F.cross_entropy(logits, y) + + if self.task_idx > 0: + kd_lambda = self.known_classes / self.total_classes + loss_kd = KD_loss( + logits[:, : self.known_classes], + self.old_network(x), + ) + loss = (1 - kd_lambda) * loss + kd_lambda * loss_kd + + + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() + + return pred, acc / x.size(0), loss + + def inference(self, data): + ''' + Perform inference on the input data. + + Args: + data (dict): Dictionary containing input data and labels. + + Returns: + tuple: Tuple containing predictions and accuracy. + ''' + x, y = data['image'].to(self.device), data['label'].to(self.device) + + logits = self.network(x) + pred = torch.argmax(logits, dim=1) + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0) + + def forward(self, x): + return self.network(x) + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + Do before every task for task initialization. + + Args: + task_idx (int): Index of the current task. + buffer (Buffer): Buffer object. + train_loader (DataLoader): DataLoader for training data. + test_loaders (list): List of DataLoaders for test data. + ''' + self.total_classes += self.kwargs['init_cls_num'] + self.network.update_classifier(self.total_classes) + + self.total_classes_indexes = np.arange(self.known_classes, self.total_classes) + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + ''' + Do after every task for updating the model. + + Args: + task_idx (int): Index of the current task. + buffer (Buffer): Buffer object. + train_loader (DataLoader): DataLoader for training data. + test_loaders (list): List of DataLoaders for test data. + ''' + if self.task_idx > 0: + self.network.classifier_weight_align(self.total_classes - self.known_classes) + self.old_network = copy.deepcopy(self.network).freeze() + self.known_classes = self.total_classes + + # update buffer + buffer.reduce_old_data(self.task_idx, self.total_classes) + val_transform = test_loaders[0].dataset.trfms + buffer.update(self.network, train_loader, val_transform, + self.task_idx, self.total_classes, self.total_classes_indexes, + self.device) + + self.task_idx += 1 diff --git a/core/scheduler.py b/core/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..22d72a53b37bde0257525fb9324b6b541b97c526 --- /dev/null +++ b/core/scheduler.py @@ -0,0 +1,124 @@ +from torch.optim import Optimizer +import math + +class _LRScheduler(object): + def __init__(self, optimizer, last_epoch=-1): + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + if last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + else: + for i, group in enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + self.step(epoch = last_epoch + 1) + self.last_epoch = last_epoch + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + Arguments: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_lr(self): + raise NotImplementedError + + def step(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.last_epoch = epoch + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group['lr'] = lr + +class CosineSchedule(_LRScheduler): + + def __init__(self, optimizer, K): + self.K = K + super().__init__(optimizer, -1) + + def cosine(self, base_lr): + if self.K == 1: + return base_lr * math.cos((99 * math.pi * (self.last_epoch)) / (200 * (2-1))) + return base_lr * math.cos((99 * math.pi * (self.last_epoch)) / (200 * (self.K-1))) + + def get_lr(self): + return [self.cosine(base_lr) for base_lr in self.base_lrs] + + def get_last_lr(self): + return self.get_lr() + +class CosineAnnealingWarmUp(_LRScheduler): + + def __init__(self, optimizer, warmup_length, T_max = 0, last_epoch = -1): + self.warmup_length = warmup_length + self.T_max = T_max + self.last_epoch = last_epoch + + super().__init__(optimizer, last_epoch) + + def cosine_lr(self, base_lr): + + return base_lr * 0.5 * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) + + def warmup_lr(self, base_lr): + + return base_lr * (self.last_epoch + 1) / self.warmup_length + + def get_lr(self): + if self.last_epoch < self.warmup_length: + return [self.warmup_lr(base_lr) for base_lr in self.base_lrs] + else: + return [self.cosine_lr(base_lr) for base_lr in self.base_lrs] + + def get_last_lr(self): + assert self.T_max > 0, 'CosineAnnealingWarmUp is called with T_max <= 0, Check your code' + return self.get_lr() + +class PatienceSchedule(_LRScheduler): + + def __init__(self, optimizer, patience, factor): + self.factor = factor # Factor to reduce the learning rate + self.patience = patience # Number of epochs with no improvement + self.best_loss = float('inf') # Best loss seen so far + self.counter = 0 # Counter for patience + + super().__init__(optimizer, -1) + + def step(self, current_loss = None, **kwargs): + # Some scheduler step function is called with parameter epoch + # use kwargs to save it and don't do anything to it + + if current_loss is None: + return 0 + + # Check if the current loss improved + if current_loss < self.best_loss: + self.best_loss = current_loss # Update the best loss + self.counter = 0 # Reset counter since we have an improvement + else: + + self.counter += 1 # Increment counter if no improvement + + # If patience is exhausted, reduce the learning rate + if self.counter >= self.patience: + for param_group in self.optimizer.param_groups: + param_group['lr'] /= self.factor # Reduce learning rate by the factor + print(f"Reducing learning rate to {self.optimizer.param_groups[0]['lr']:.5f}") + self.counter = 0 # Reset counter after reducing learning rate + + def get_last_lr(self): + return self.optimizer.param_groups[0]['lr'] \ No newline at end of file diff --git a/core/trainer.py b/core/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..dbca6decb09f6171e88be3d80575e487555c8c19 --- /dev/null +++ b/core/trainer.py @@ -0,0 +1,720 @@ +import os +import sys +import torch + +import numpy as np +import core.model as arch +import torch.optim as optim +import torch.distributed as dist +import torch.multiprocessing as mp + +from pprint import pprint +from contextlib import redirect_stdout +from time import time +from tqdm import tqdm +from core.data import get_dataloader +from core.utils import * +from core.model.buffer import * +from core.model import bic +from torch.utils.data import DataLoader +from core.utils import Logger, fmt_date_str +from torch.optim.lr_scheduler import MultiStepLR, LambdaLR +from copy import deepcopy + +from core.scheduler import CosineSchedule, PatienceSchedule, CosineAnnealingWarmUp + +class Trainer(object): + """ + The Trainer. + + Build a trainer from config dict, set up optimizer, model, etc. + """ + + def __init__(self, rank, config): + + self.rank = rank + self.config = config + self.distribute = self.config['n_gpu'] > 1 # 暂时不考虑分布式训练 + assert not self.distribute + if self.distribute: + dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=self.config['n_gpu'], rank=rank) + self.logger = self._init_logger(config) + self.device = self._init_device(config) + + #pprint(config) + # Write config into log file only + with redirect_stdout(self.logger.file): + pprint(config) + + + self.init_cls_num, self.inc_cls_num, self.task_num = self._init_data(config) + self.model = self._init_model(config) + ( + self.train_loader, + self.test_loader, + ) = self._init_dataloader(config) + + self.buffer = self._init_buffer(config) + + self.task_idx = 0 + ( + self.init_epoch, + self.inc_epoch, + self.optimizer, + self.scheduler, + ) = self._init_optim(config) + + self.train_meter, self.test_meter = self._init_meter() + + self.val_per_epoch = config['val_per_epoch'] + + if self.config["classifier"]["name"] == "bic": + self.stage2_epoch = config['stage2_epoch'] + + def _init_logger(self, config, mode='train'): + ''' + Init logger. + + Args: + config (dict): Parsed config file. + + Returns: + logger (Logger) + ''' + + save_path = config['save_path'] + + log_path = os.path.join(save_path, "log", config['classifier']['name']) + os.makedirs(log_path, exist_ok=True) + + log_prefix = f"{config['dataset']}..{config['backbone']['name']}--ep{config['epoch']}--s{config['seed']}__{datetime.now().strftime('%Y-%m-%d_%H-%M')}" + log_file = os.path.join(log_path, f"{log_prefix}.log") + logger = Logger(log_file) + + # hack sys.stdout + sys.stdout = logger + + return logger + + def _init_device(self, config): + """" + Init the devices from the config. + + Args: + config(dict): Parsed config file. + + Returns: + device: a device. + """ + init_seed(config['seed'], config['deterministic']) + + device = torch.device(f'cuda:{config["device_ids"][self.rank]}') + torch.cuda.set_device(device) + + return device + + def _init_files(self, config): + pass + + def _init_writer(self, config): + pass + + def _init_meter(self, ): + """ + Init the AverageMeter of train/val/test stage to cal avg... of batch_time, data_time,calc_time ,loss and acc1. + + Returns: + tuple: A tuple of train_meter, val_meter, test_meter. + """ + train_meter = AverageMeter( + "train", + ["batch_time", "data_time", "calc_time", "loss", "acc1"], + ) + + test_meter = AverageMeter( + "test", + ["batch_time", "data_time", "calc_time", "acc1"], + ) + + return train_meter, test_meter + + def _init_optim(self, config): + """ + Init the optimizers and scheduler from config, if necessary, load the state dict from a checkpoint. + + Args: + config (dict): Parsed config file. + + Returns: + tuple: A tuple of optimizer, scheduler. + """ + + if 'init_epoch' in config.keys(): + init_epoch = config['init_epoch'] + else: + init_epoch = config['epoch'] + + model = self.model.module if self.distribute else self.model + + if self.task_idx == 0 and 'init_optimizer' in config.keys(): + optimizer = get_instance( + torch.optim, "init_optimizer", config, params=model.get_parameters(config) + ) + else: + optimizer = get_instance( + torch.optim, "optimizer", config, params=model.get_parameters(config) + ) + + # Check if the learning rate scheduler specified in the configuration is "CosineSchedule" + if config['lr_scheduler']['name'] == "CosineSchedule": + scheduler = CosineSchedule(optimizer, K=config['lr_scheduler']['kwargs']['K']) + elif config['lr_scheduler']['name'] == "PatienceSchedule": + scheduler = PatienceSchedule(optimizer, patience = config['lr_scheduler']['kwargs']['patience'], factor = config['lr_scheduler']['kwargs']['factor']) + elif config['lr_scheduler']['name'] == "Constant": + scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 1) + elif config['lr_scheduler']['name'] == "CosineAnnealingWarmUp": + T_max = len(self.train_loader.get_loader(self.task_idx)) + T_max *= init_epoch if self.task_idx == 0 else config['epoch'] + scheduler = CosineAnnealingWarmUp(optimizer, config['lr_scheduler']['kwargs']['warmup_length'], T_max) + else: + scheduler = get_instance(torch.optim.lr_scheduler, "lr_scheduler", config, optimizer=optimizer) + + return init_epoch, config['epoch'], optimizer, scheduler + + def _init_data(self, config): + return config['init_cls_num'], config['inc_cls_num'], config['task_num'] + + def _init_model(self, config): + """ + Init model(backbone+classifier) from the config dict and load the pretrained params or resume from a + checkpoint, then parallel if necessary . + + Args: + config (dict): Parsed config file. + + Returns: + tuple: A tuple of the model and model's type. + """ + # For backward compatibility, some backbone initialization doesn't take device as argument + try: + backbone = get_instance(arch, "backbone", config, **{'device': self.device}) + except TypeError: + backbone = get_instance(arch, "backbone", config) + + model = get_instance(arch, "classifier", config, **{'device': self.device, 'backbone': backbone}).to(self.device) + + if self.distribute: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.device] + ) + + return model + + def _init_dataloader(self, config): + ''' + Init DataLoader + + Args: + config (dict): Parsed config file. + + Returns: + train_loaders (list): Each task's train dataloader. + test_loaders (list): Each task's test dataloader. + ''' + + train_loaders = get_dataloader(config, "train") + test_loaders = get_dataloader(config, "test", cls_map=train_loaders.cls_map) + + # Add DistributedSampler to each dataloader + if self.distribute: + for loaders in [train_loaders, test_loaders]: + for i, dataloader in enumerate(loaders.dataloaders): + dataset = dataloader.dataset + sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True) + loaders.dataloaders[i] = DataLoader( + dataset, + sampler=sampler, + batch_size=dataloader.batch_size // self.config['n_gpu'], + num_workers=dataloader.num_workers, + drop_last=dataloader.drop_last + ) + + return train_loaders, test_loaders + + def _init_buffer(self, config): + ''' + Init Buffer + + Args: + config (dict): Parsed config file. + + Returns: + buffer (Buffer): a buffer for old samples. + ''' + buffer = get_instance(arch, "buffer", config) + + return buffer + + def train_loop(self,): + """ + The norm train loop: before_task, train, test, after_task + """ + experiment_begin = time() + method_name = self.config["classifier"]["name"] + testing_times = self.config['testing_times'] + + # 记录每个 task 的 average last accuracy + batch_last_acc_list = np.zeros((self.task_num)) + task_last_acc_list = np.zeros((self.task_num)) + + # 记录每个 task 的 best last accuracy + best_batch_last_acc_list = np.zeros((self.task_num)) + best_task_last_acc_list = np.zeros((self.task_num)) + + acc_table = np.zeros((self.task_num, self.task_num)) # A numpy array with shape [task_num, task_num], where [i, j] is acc of model on task j after learning task i + bwt_list, frgt_list = [], [] + + model = self.model.module if self.distribute else self.model + + if method_name == 'RAPF': + model.model.classes_names = self.train_loader.cls_map + + for task_idx in range(self.task_num): + self.task_idx = task_idx + if self.rank == 0: + print(f"================Task {task_idx} Start!================") + + if hasattr(model, 'before_task'): + model.before_task(task_idx, self.buffer, self.train_loader.get_loader(task_idx), self.test_loader.get_loader(task_idx)) + + if self.rank == 0: + print(f"Trainable Parameters for Task {task_idx} : {count_parameters(model)} / {count_all_parameters(model)} ({count_parameters(model)*100/count_all_parameters(model):.2f}%)") + + _, _, self.optimizer, self.scheduler = self._init_optim(self.config) + dataloader = self.train_loader.get_loader(task_idx) + + if method_name == "bic": + + w_decay = 2e-4 * self.task_num / (task_idx + 1) # in source code? + self.optimizer = optim.SGD(model.get_parameters(self.config), lr = 0.1, momentum = 0.9, weight_decay = w_decay) + self.scheduler = MultiStepLR(self.optimizer, milestones = [100, 150, 200], gamma = 0.1) + + dataloader, val_bias_dataloader = self.model.spilt_and_update(dataloader, self.buffer, task_idx, self.config) + + elif isinstance(self.buffer, (LinearBuffer, LinearHerdingBuffer)) and self.buffer.buffer_size > 0 and task_idx > 0: + datasets = dataloader.dataset + if isinstance(datasets.images, list): + datasets.images.extend(self.buffer.images) + datasets.labels.extend(self.buffer.labels) + elif isinstance(datasets.images, np.ndarray): + datasets.images = np.concatenate((datasets.images, self.buffer.images), axis=0) + datasets.labels = np.concatenate((datasets.labels, self.buffer.labels), axis=0) + else: + assert 0 + + dataloader = DataLoader( + datasets, + shuffle = True, + batch_size = self.config['batch_size'], + drop_last = False, + num_workers = self.config['num_workers'] + ) + + if method_name in ["LoRAsub_DRS"]: + print('Replacing Optim & Scheduler') + self.optimizer = self.model.get_optimizer(self.config['optimizer']['kwargs']['lr'], self.config['optimizer']['kwargs']['weight_decay']) + self.scheduler = CosineSchedule(self.optimizer, K=self.config['epoch']) + + if method_name == 'CL_LoRA': + self.model.set_optim(self.optimizer) + + if self.rank == 0: + print(f"================Task {task_idx} Training!================") + print(f"The training samples number : {len(dataloader.dataset)}") + + # Reset Best Record + best_batch_last_acc, best_task_last_acc = 0., 0. + best_bwt, best_frgt = float('-inf'), float('inf') + + for epoch_idx in range(self.init_epoch if task_idx == 0 else self.inc_epoch): + if self.rank == 0: + print("================Train on train set================") + train_meter = self._train(epoch_idx, dataloader) + + acc1 = torch.tensor(train_meter.avg("acc1"), device=self.device) + loss = torch.tensor(train_meter.avg("loss"), device=self.device) + if self.distribute: + # Aggregate accuracy across all processes + dist.barrier() + dist.all_reduce(acc1, op=dist.ReduceOp.SUM) # Sum accuracy across processes + dist.all_reduce(loss, op=dist.ReduceOp.SUM) + acc1 = acc1 / self.config['n_gpu'] # Normalize by world size + loss = loss / self.config['n_gpu'] + dist.barrier() + + acc1 = acc1.item() + loss = loss.item() + + if self.rank == 0: + print(f"Epoch [{epoch_idx}/{self.init_epoch if task_idx == 0 else self.inc_epoch}] Learning Rate {self.scheduler.get_last_lr()}\t|\tLoss: {loss:.4f} \tAverage Acc: {acc1:.2f} ") + + if (epoch_idx+1) % self.val_per_epoch == 0 or (epoch_idx+1) == self.inc_epoch: + if self.rank == 0: + print(f"================Validation on test set================") + + # Disable validation for some method + if method_name in ['TRGP', + 'RanPAC', + 'MInfLoRA2', + 'MInfLoRA3', + 'PRAKA', + 'TRGP_CLIP', + 'LoRAsub_DRS', + 'CL_LoRA' + ]: + if self.rank == 0: + print(f" * Disabled validation for this method") + else: + test_acc = self._validate(task_idx) + + batch_last_acc, per_task_acc = test_acc['avg_acc'], test_acc['per_task_acc'] + best_batch_last_acc = max(batch_last_acc, best_batch_last_acc) + + task_last_acc = np.mean(per_task_acc) + best_task_last_acc = max(task_last_acc, best_task_last_acc) + + frgt, bwt = compute_frgt(acc_table, per_task_acc, task_idx), compute_bwt(acc_table, per_task_acc, task_idx) + best_frgt, best_bwt = min(frgt, best_frgt), max(bwt, best_bwt) + + if self.rank == 0: + print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})") + print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})") + print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})") + print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})") + print(f" * Per-Task Acc: {per_task_acc}") + + if self.config['lr_scheduler']['name'] == "PatienceSchedule": + self.scheduler.step(train_meter.avg('loss')) + if self.scheduler.get_last_lr() < self.config['lr_scheduler']['kwargs']['stopping_lr']: + if self.rank == 0: + print(f"{self.scheduler.get_last_lr()} < {self.config['lr_scheduler']['kwargs']['stopping_lr']}, stopping this task now") + break + else: + self.scheduler.step() + + if hasattr(model, 'after_task'): + model.after_task(task_idx, self.buffer, self.train_loader.get_loader(task_idx), self.test_loader.get_loader(task_idx)) + + # Update Buffer + if method_name not in ['bic', 'ERACE', 'ERAML']: + self.buffer.total_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num + if self.buffer.buffer_size > 0: + if self.buffer.strategy == 'herding': + herding_update(self.train_loader.get_loader(task_idx).dataset, self.buffer, model.backbone, self.device) + elif self.buffer.strategy == 'random': + random_update(self.train_loader.get_loader(task_idx).dataset, self.buffer) + elif self.buffer.strategy == 'balance_random': + balance_random_update(self.train_loader.get_loader(task_idx).dataset, self.buffer) + + # Stage 2 Training : BIC (Stage 2 start after buffer being updated) + if self.config["classifier"]["name"] == "bic" and task_idx > 0: + + bias_scheduler = optim.lr_scheduler.LambdaLR(model.bias_optimizer, lr_lambda=lambda e: 1) + + for epoch_idx in range(self.stage2_epoch): + if self.rank == 0: + print("================ Train on the train set (stage2)================") + train_meter = self.stage2_train(epoch_idx, val_bias_dataloader) + + if self.rank == 0: + print(f"Epoch [{epoch_idx}/{self.stage2_epoch}] Learning Rate {bias_scheduler.get_last_lr()}\t|\tLoss: {train_meter.avg('loss'):.4f} \tAverage Acc: {train_meter.avg('acc1'):.2f} ") + + if (epoch_idx+1) % self.val_per_epoch == 0 or (epoch_idx+1) == self.inc_epoch: + if self.rank == 0: + print("================ Test on the test set (stage2)================") + + test_acc = self._validate(task_idx) + + batch_last_acc, per_task_acc = test_acc['avg_acc'], test_acc['per_task_acc'] + best_batch_last_acc = max(batch_last_acc, best_batch_last_acc) + + task_last_acc = np.mean(per_task_acc) + best_task_last_acc = max(task_last_acc, best_task_last_acc) + + frgt, bwt = compute_frgt(acc_table, per_task_acc, task_idx), compute_bwt(acc_table, per_task_acc, task_idx) + best_frgt, best_bwt = min(frgt, best_frgt), max(bwt, best_bwt) + + if self.rank == 0: + print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})") + print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})") + print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})") + print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})") + print(f" * Per-Task Acc: {per_task_acc}") + + #bias_scheduler.step() + + for test_idx in range(testing_times): + if self.rank == 0: + print(f"================Test {test_idx+1}/{testing_times} of Task {task_idx}!================") + + test_acc = self._validate(task_idx) + + batch_last_acc, per_task_acc = test_acc['avg_acc'], test_acc['per_task_acc'] + best_batch_last_acc = max(batch_last_acc, best_batch_last_acc) + + task_last_acc = np.mean(per_task_acc) + best_task_last_acc = max(task_last_acc, best_task_last_acc) + + frgt, bwt = compute_frgt(acc_table, per_task_acc, task_idx), compute_bwt(acc_table, per_task_acc, task_idx) + best_frgt, best_bwt = min(frgt, best_frgt), max(bwt, best_bwt) + + if self.rank == 0: + print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})") + print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})") + print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})") + print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})") + print(f" * Per-Task Acc: {per_task_acc}") + + batch_last_acc_list[task_idx] += batch_last_acc # avg_acc_list[task_idx] += avg_acc + task_last_acc_list[task_idx] += task_last_acc + acc_table[task_idx][:task_idx + 1] += np.array(per_task_acc) + + best_batch_last_acc_list[task_idx] = best_batch_last_acc + best_task_last_acc_list[task_idx] = best_task_last_acc + + # Take mean of testing_times + batch_last_acc_list[task_idx] /= testing_times + task_last_acc_list[task_idx] /= testing_times + acc_table[task_idx] /= testing_times + + batch_last_acc = batch_last_acc_list[task_idx] + task_last_acc = task_last_acc_list[task_idx] + + frgt, bwt = compute_frgt(acc_table, acc_table[task_idx], task_idx), compute_bwt(acc_table, acc_table[task_idx], task_idx) + best_frgt, best_bwt = min(frgt, best_frgt), max(bwt, best_bwt) + if task_idx > 1: + frgt_list.append(frgt) + bwt_list.append(bwt) + + if self.rank == 0: + print(f"================Result of Task {task_idx} Testing!================") + print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})") + print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})") + print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})") + print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})") + print(f" * Per-Task Acc: {acc_table[task_idx][:task_idx + 1]}") + + batch_ovr_avg_acc = np.mean(batch_last_acc_list) #batch_ovr_avg_acc = np.mean(avg_acc_list) + best_batch_ovr_avg_acc = np.mean(best_batch_last_acc_list) # best_batch_ovr_avg_acc = np.mean(best_avg_acc_list) + + task_ovr_avg_acc = np.sum(np.sum(acc_table[:task_idx + 1], axis = 1) / np.arange(1, task_idx + 2)) / (task_idx + 1) + + ovr_bwt = np.mean(bwt_list) if len(bwt_list) > 0 else float('-inf') + ovr_frgt = np.mean(frgt_list) if len(frgt_list) > 0 else float('inf') + + if self.rank == 0: + print(f"================Overall Result of {self.task_num} Tasks!================") + print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})") + print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})") + print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})") + print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})") + print(f" * [Batch] Overall Avg Acc : {batch_ovr_avg_acc:.2f} (Best: {best_batch_ovr_avg_acc:.2f})") + print(f" * [Task] Overall Avg Acc : {task_ovr_avg_acc:.2f}") + print(f" * Overall Frgt : {ovr_frgt:.3f}") + print(f" * Overall BwT : {ovr_bwt:.2f}") + print(f" * Average Acc Table : \n{acc_table}") + + print(f"================Model Performance Analysis================") + print(f" * Time Costs : {(time() - experiment_begin):.2f} sec") + fps = compute_fps(model, self.config) + avg_fps, best_fps = fps['avg_fps'], fps['best_fps'] + print(f" * Average FPS (Best FPS) : {avg_fps:.0f} ({best_fps:.0f})") + + def stage2_train(self, epoch_idx, dataloader): + """ + The stage 2 train stage of method : BIC + + Args: + epoch_idx (int): Epoch index + + Returns: + dict: {"avg_acc": float} + """ + model = self.model.module if self.distribute else self.model + + model.eval() + for layer in model.bias_layers: + layer.train() + + meter = self.train_meter + meter.reset() + + total = len(dataloader) + for b, batch in tqdm(enumerate(dataloader), total=total, disable=(self.rank != 0)): + + output, acc, loss = model.stage2(batch) + + meter.update("acc1", 100 * acc) + meter.update("loss", loss.item()) + + return meter + + def _train(self, epoch_idx, dataloader): + """ + The train stage. + + Args: + epoch_idx (int): Epoch index + + Returns: + dict: {"avg_acc": float} + """ + model = self.model.module if self.distribute else self.model + + model.train() + if self.config['classifier']['name'] == 'bic': + for layer in model.bias_layers: + layer.eval() + + meter = deepcopy(self.train_meter) + meter.reset() + + total = len(dataloader) + init_seed(self.config['seed'] + epoch_idx, self.config['deterministic']) # Ensure Reproducibility + for b, batch in tqdm(enumerate(dataloader), total=total, disable=(self.rank != 0)): + + batch['batch_id'] = b + + # These method's LR is updated every iterations, not epochs + if self.config['classifier']['name'] in ['MOE_ADAPTER4CL', 'DMNSP', 'DMNSP_CIL']: + self.scheduler.step(total * epoch_idx + b) + + if self.config["classifier"]["name"] in ['TRGP', 'DMNSP', 'DMNSP_CIL', 'TRGP_CLIP', + 'GPM', 'MoE_Test2', 'API', 'L2P']: + self.optimizer.zero_grad() + output, acc, loss = model.observe(batch) + elif self.config["classifier"]["name"] in ['bic']: + output, acc, loss = model.observe(batch) + self.optimizer.zero_grad() + loss.backward(retain_graph=True) + else: + output, acc, loss = model.observe(batch) + self.optimizer.zero_grad() + loss.backward() + + self.optimizer.step() + + if self.config["classifier"]["name"] in ['ERACE', 'ERAML']: + model.add_reservoir() + + meter.update("acc1", 100 * acc) + meter.update("loss", loss.item()) + + return meter + + def _validate(self, task_idx): + + dataloaders = self.test_loader.get_loader(task_idx) + + model = self.model.module if self.distribute else self.model + model.eval() + + if self.config["classifier"]["name"] == 'bic': + for layer in model.bias_layers: + layer.eval() + + per_task_acc = [] + count_all, correct_all = 0, 0 + + if self.config['testing_per_task']: + + count_task, correct_task = 0, 0 + + with torch.no_grad(): + for t, dataloader in enumerate(dataloaders): + correct_task, count_task = 0, 0 + + for b, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc = f"Testing on Task {t} data", disable=self.rank != 0): # Disable tqdm for non-master processes + + if self.config['setting'] == 'task-aware': + output, acc = model.inference(batch, task_id=t) + elif self.config['setting'] == 'task-agnostic': + output, acc = model.inference(batch) + + correct_task += int(acc * batch['label'].shape[0]) + count_task += batch['label'].shape[0] + + correct_all += correct_task + count_all += count_task + + if self.distribute: + pass + + per_task_acc.append(round(correct_task * 100 / count_task, 2)) + + if self.distribute: + pass + + else: + + datasets = [dl.dataset for dl in dataloaders] + + all_images = np.concatenate([ds.images for ds in datasets], axis=0) + all_labels = np.concatenate([ds.labels for ds in datasets], axis=0) + + merged_dataset = copy.deepcopy(datasets[0]) + merged_dataset.images = all_images + merged_dataset.labels = all_labels + + merged_loader = DataLoader( + merged_dataset, + shuffle = True, + batch_size = self.config['batch_size'], + drop_last = False, + num_workers = self.config['num_workers'], + pin_memory=False + ) + + class_boundaries = [] + start_cls = 0 + for t in range(task_idx + 1): + n_cls = self.init_cls_num if t == 0 else self.inc_cls_num + class_boundaries.append((start_cls, start_cls + n_cls)) + start_cls += n_cls + + correct_by_task = np.zeros(task_idx + 1, dtype=int) + count_by_task = np.zeros(task_idx + 1, dtype=int) + + # 4. 推理 + with torch.no_grad(): + for b, batch in tqdm(enumerate(merged_loader), total=len(merged_loader), desc=f"Testing merged tasks <= {task_idx}", disable=self.rank != 0): + + if self.config['setting'] == 'task-aware': + print('Mostly methods dont support this, set testing_per_task to False') + raise NotImplementedError + output, acc = model.inference(batch, task_id=None) + elif self.config['setting'] == 'task-agnostic': + output, acc = model.inference(batch) + preds = output.cpu().numpy() + + labels = batch['label'].cpu().numpy() + correct_all += np.sum(preds == labels) + + count_all += len(labels) + + # 统计每个 task 的正确率 + for t, (start, end) in enumerate(class_boundaries): + mask = (labels >= start) & (labels < end) + if np.any(mask): + correct_by_task[t] += np.sum(preds[mask] == labels[mask]) + count_by_task[t] += np.sum(mask) + + per_task_acc = [round(c * 100 / n, 2) if n > 0 else 0 for c, n in zip(correct_by_task, count_by_task)] + + avg_acc = round(correct_all * 100 / count_all, 2) + + return { + "avg_acc": avg_acc, + "per_task_acc": per_task_acc + } diff --git a/core/utils/__init__.py b/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c32b4b26a7a9fb8a65b3a540dc833f1a468254a --- /dev/null +++ b/core/utils/__init__.py @@ -0,0 +1,2 @@ +from .utils import * +from .logger import * \ No newline at end of file diff --git a/core/utils/logger.py b/core/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e12168c695c5569d7b2994d1a5ebe9d19d24bcb8 --- /dev/null +++ b/core/utils/logger.py @@ -0,0 +1,37 @@ +import os +import sys + + +class Logger(object): + def __init__(self, fpath): + self.console = sys.stdout + self.file = open(fpath, "w") + + def __del__(self): + self.close() + + def __enter__(self): + pass + + def __exit__(self, *args): + self.close() + + def write(self, msg): + self.console.write(msg) + if self.file is not None: + self.file.write(msg) + else: + self.console.write("Warning: Log file is None") + + def flush(self): + self.console.flush() + if self.file is not None: + self.file.flush() + os.fsync(self.file.fileno()) + else: + self.console.write("Warning: Log file is None") + + def close(self): + self.console.close() + if self.file is not None: + self.file.close() diff --git a/core/utils/utils.py b/core/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c6cba3c568b282c3ccd72f036d9c48cc426cb053 --- /dev/null +++ b/core/utils/utils.py @@ -0,0 +1,257 @@ +import pandas as pd +import os +import torch +from datetime import datetime +import numpy as np +import random +from time import time +from torch.optim.lr_scheduler import _LRScheduler + + +class AverageMeter(object): + """ + A AverageMeter to meter avg of number-like data. + """ + + def __init__(self, name, keys, writer=None): + self.name = name + self._data = pd.DataFrame( + index=keys, columns=["last_value", "total", "counts", "average"] + ) + self.writer = writer + self.reset() + + def reset(self): + for col in self._data.columns: + self._data[col].values[:] = 0 + + def update(self, key, value, n=1): + if self.writer is not None: + tag = "{}/{}".format(self.name, key) + self.writer.add_scalar(tag, value) + # self._data.last_value[key] = value + # self._data.total[key] += value * n + # self._data.counts[key] += n + # self._data.average[key] = self._data.total[key] / self._data.counts[key] + self._data.loc[key, 'last_value'] = value + self._data.loc[key, 'total'] += value * n + self._data.loc[key, 'counts'] += n + self._data.loc[key, 'average'] = self._data.loc[key, 'total'] / self._data.loc[key, 'counts'] + + + def avg(self, key): + return self._data.average[key] + + def result(self): + return dict(self._data.average) + + def last(self, key): + return self._data.last_value[key] + + def total(self, key): + return self._data.total[key] + + + +def init_seed(seed=0, deterministic=False): + """ + + :param seed: + :param deterministic: + :return: + """ + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if deterministic: + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + else: + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + +def get_instance(module, name, config, **kwargs): + """ + A reflection function to get backbone/classifier/. + + Args: + module ([type]): Package Name. + name (str): Top level value in config dict. (backbone, classifier, etc.) + config (dict): The parsed config dict. + + Returns: + Corresponding instance. + """ + if config[name]["kwargs"] is not None: + kwargs.update(config[name]["kwargs"]) + + return getattr(module, config[name]["name"])(**kwargs) + +# https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py +class GradualWarmupScheduler(_LRScheduler): + """Gradually warm-up(increasing) learning rate in optimizer. + Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. + Args: + optimizer (Optimizer): Wrapped optimizer. + multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. + total_epoch: target learning rate is reached at total_epoch, gradually + after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) + """ + + def __init__(self, optimizer, config): + # if self.multiplier < 1.: + # raise ValueError('multiplier should be greater thant or equal to 1.') + self.optimizer = optimizer + self.total_epoch = config["epoch"] + self.warmup = config["warmup"] + self.after_scheduler = self.get_after_scheduler(config) + self.finish_warmup = False + super(GradualWarmupScheduler, self).__init__(optimizer) + + def get_after_scheduler(self, config): + scheduler_name = config["lr_scheduler"]["name"] + scheduler_dict = config["lr_scheduler"]["kwargs"] + + if self.warmup != 0: + if scheduler_name == "CosineAnnealingLR": + scheduler_dict["T_max"] -= self.warmup - 1 + elif scheduler_name == "MultiStepLR": + scheduler_dict["milestones"] = [ + step - self.warmup + 1 for step in scheduler_dict["milestones"] + ] + + if scheduler_name == "LambdaLR": + return torch.optim.lr_scheduler.LambdaLR( + self.optimizer, + lr_lambda=eval(config["lr_scheduler"]["kwargs"]["lr_lambda"]), + last_epoch=-1, + ) + + return getattr(torch.optim.lr_scheduler, scheduler_name)( + optimizer=self.optimizer, **scheduler_dict + ) + + def get_lr(self): + if self.last_epoch >= self.warmup - 1: + self.finish_warmup = True + return self.after_scheduler.get_last_lr() + + return [ + base_lr * float(self.last_epoch + 1) / self.warmup + for base_lr in self.base_lrs + ] + + def step_ReduceLROnPlateau(self, metrics, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.last_epoch = ( + epoch if epoch != 0 else 1 + ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning + if self.last_epoch <= self.total_epoch: + warmup_lr = [ + base_lr + * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) + for base_lr in self.base_lrs + ] + for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): + param_group["lr"] = lr + else: + if epoch is None: + self.after_scheduler.step(metrics, None) + else: + self.after_scheduler.step(metrics, epoch - self.total_epoch) + + def step(self, epoch=None, metrics=None): + if type(self.after_scheduler) != torch.optim.lr_scheduler.ReduceLROnPlateau: + if self.finish_warmup and self.after_scheduler: + if epoch is None: + self.after_scheduler.step(None) + else: + self.after_scheduler.step(epoch - self.warmup) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super(GradualWarmupScheduler, self).step(epoch) + else: + self.step_ReduceLROnPlateau(metrics, epoch) + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + +def count_all_parameters(model): + + return sum(p.numel() for p in model.parameters()) + +def fmt_date_str(date=None, fmt="%y-%m-%d-%H-%M-%S"): + """Format date to string. + + Args: + datetime (datetime, optional): get current time if None. Defaults to None. + + Returns: + str: formatted date string + """ + if date is None: + date = datetime.now() + return date.strftime(fmt) + +def compute_bwt(acc_table, curr_acc, task_idx): + ''' + After training T tasks, $BWT = \frac{\sum_{i=3}^T\sum_{j=1}^{i-2}R_{i,j}-R{j,j}}{T(T-1)/2}$ + Equivalent to Positive BwT of Continuum + https://arxiv.org/pdf/1810.13166 + ''' + + if task_idx > 1: + + bwt = 0. + for i in range(2, task_idx): + for j in range(i - 1): + bwt += acc_table[i, j] - acc_table[j, j] + + for j in range(task_idx - 1): + bwt += curr_acc[j] - acc_table[j, j] + + return (bwt * 2) / (task_idx * (task_idx+1)) + + return 0. + + +def compute_frgt(acc_table, curr_acc, task_idx): + ''' + After training T tasks, $Frgt = \frac{\sum_{j=1}^{T-2}R_{T-1,j}-R_{j,j}}{T-1}$ + Equivalent to Forgetting of Continuum + ''' + + if task_idx > 1: + return sum(np.diag(acc_table)[:task_idx - 1] - curr_acc[:task_idx+1][:-2]) / task_idx + return 0. + + +def compute_fps(model, config): + model.eval() + + # assert image_size in config is Correct, check ranpac + + data = { + 'image' : torch.rand((2, 3, config['image_size'], config['image_size'])).cuda(), + 'label' : torch.zeros((2)) + } + + t_all = [] + + for i in range(100): + t1 = time() + if config['setting'] == 'task-aware': + model.inference(data, task_id = random.randint(0, config['task_num'] - 1)) + elif config['setting'] == 'task-agnostic': + model.inference(data) + t2 = time() + t_all.append(t2 - t1) + + return {'avg_fps' : 1 / np.mean(t_all), + 'best_fps' : 1 / min(t_all)} \ No newline at end of file diff --git a/docs/tutorials/en/add_a_new_method_en.md b/docs/tutorials/en/add_a_new_method_en.md new file mode 100644 index 0000000000000000000000000000000000000000..a414e1d4a2e941188ae1f4df24170e8ae951ccee --- /dev/null +++ b/docs/tutorials/en/add_a_new_method_en.md @@ -0,0 +1,201 @@ +# Add a new method + +Taking the [`LUCIR`](https://openaccess.thecvf.com/content_CVPR_2019/html/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.html) method as an example, we will describe how to add a new method. + +Before this, we need to introduce a parent class of all methods:`Finetune`. + +```python +class Finetune(nn.Module): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + ... + self.kwargs = kwargs + + def observe(self, data): + ... + return pred, acc / x.size(0), loss + + def inference(self, data): + ... + return pred, acc / x.size(0) + + def forward(self, x): + ... + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + pass + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + pass + + def get_parameters(self, config): + ... + return train_parameters +``` +The `Finetune` class includes several important interfaces that a method should have. ++ `__init__`: init func,set the initialize parameters required by the algorithm. ++ `observe`:used to be called in train phase, input a batch of training samples and return predictions, accuracy, and forward loss. ++ `inference`:used to be called in inference phase, input a batch of test samples and return the classification result and accuracy. ++ `forward`:override the forward function `forward` of `Module` in `pytorch`, return the ouput of `backbone`. ++ `before_task`:called before training starts for each task, used to adjust model structure, training parameters, etc., and requires user customization. ++ `after_task`:called after training starts for each task, used to adjust model structure, buffer, etc., and requires user customization. ++ `get_parameters`:called before training starts for each task, returns the training parameters for the current task. + + +## LUCIR + +### Build model +First, create `LUCIR` model class, add file `lucir.py` under core/model/replay/: (this code have some differences with source code) +```python +class LUCIR(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.kwargs = kwargs + self.K = kwargs['K'] + self.lw_mr = kwargs['lw_mr'] + self.ref_model = None + + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + + self.ref_model = copy.deepcopy(self.backbone) + ... + new_fc = SplitCosineLinear(in_features, out_features, self.kwargs['inc_cls_num']) + + self.loss_fn1 = nn.CosineEmbeddingLoss() + self.loss_fn2 = nn.CrossEntropyLoss() + self.loss_fn3 = nn.MarginRankingLoss(margin=self.kwargs['dist']) + ... + + self.backbone = self.backbone.to(self.device) + if self.ref_model is not None: + self.ref_model = self.ref_model.to(self.device) + + + def _init_new_fc(self, task_idx, buffer, train_loader): + if task_idx == 0: + return + ... + self.backbone.fc.fc2.weight.data = novel_embedding.to(self.device) + + def _compute_feature(self, feature_model, loader, num_samples, num_features): + ... + + + def observe(self, data): + x, y = data['image'], data['label'] + logit = self.backbone(x) + + ... + ref_outputs = self.ref_model(x) + loss = self.loss_fn1(...) * self.cur_lamda + loss += self.loss_fn2(...) + if hard_num > 0: + ... + loss += self.loss_fn3(...) * self.lw_mr + + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0), loss + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + if self.task_idx > 0: + self.handle_ref_features.remove() + ... + + + def inference(self, data): + pass + + + def _init_optim(self, config, task_idx): + ... + tg_params =[{'params': base_params, 'lr': 0.1, 'weight_decay': 5e-4}, \ + {'params': self.backbone.fc.fc1.parameters(), 'lr': 0, 'weight_decay': 0}] + return tg_params +``` ++ In `__init__`, initialize `K, lw_mr, ref_model` required by `LUCIR`. ++ In `before_task`, according to the requirements of `LUCIR`, we update the classifier before the task starts and set different loss functions based on `task_idx`. ++ In `observe`,we use the loss function defined in `before_task` to calculate the forward loss. ++ In `after_task`, according to the `LUCIR` algorithm, some `hook` operations need to be removed. ++ In `_init_optim`, we select a subset of parameters from the entire model for training. + + +The implementation of the above interfaces is the difference between the `LUCIR` algorithm and other algorithms. Other interfaces can be left unimplemented and handled by Finetune.
+ +Note that due to the distinct operations of continual learning algorithms for the first task and subsequent tasks, `task_idx` is passed in before_task to identify the current task number.
+ + + +## Add `lucir.yaml` + +Please refer to [`config.md`](./config_file_en.md) for the meaning of each parameter +### Dataset + +```yaml +data_root: /data/fanzhichen/continual/cifar100 +image_size: 32 +save_path: ./ +init_cls_num: 50 +inc_cls_num: 10 +task_num: 6 +``` + +### Optimizer + +```yaml +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0005 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [80, 120] +``` + +### Backbone +```yaml +backbone: + name: resnet32 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + cosine_fc: True +``` + +### Buffer +`name`: `LinearBuffer` will merge the data with the current task data before the task starts.
+`strategy`:Buffer update strategy, only support `herding, random, equal_random, reservoir, None`
+```yaml +buffer: + name: LinearBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 + strategy: herding # random, equal_random, reservoir, herding +``` + + +### Algorithm +`name`:which method.
+```yaml +classifier: + name: LUCIR + kwargs: + num_class: 100 + feat_dim: 512 + init_cls_num: 50 + inc_cls_num: 10 + dist: 0.5 + lamda: 5 + K: 2 + lw_mr: 1 + +``` diff --git a/docs/tutorials/en/config_file_en.md b/docs/tutorials/en/config_file_en.md new file mode 100644 index 0000000000000000000000000000000000000000..7b2dca6a2839ab7ed044b8c479f66c86416ff678 --- /dev/null +++ b/docs/tutorials/en/config_file_en.md @@ -0,0 +1,136 @@ +The path of the configuration file is as follows: + +``` +config/* +``` + +### LibContinual Configuration File Composition + +LibContinual configuration files use the `yaml` file format. Our predefined configuration files are located in `core/config/default.yaml`, and users can put custom configuration items into the `config/` directory and save them in `.yaml` format. + +Although most configurations have been pre-written in `default.yaml`, you cannot directly use the `default.yaml` configuration to run the framework. You need to define the configuration file corresponding to the method you want to run in advance. You can refer to the parameter descriptions below to write your own configuration file. + +The `config/headers` folder contains the following files: + +- `data.yaml`: Definitions related to data configuration are in this file +- `device.yaml`: Definitions related to GPU configuration items are in this file +- `model.yaml`: Definitions related to model configuration are in this file +- `optimizer.yaml`: Definitions related to the optimizer configuration are in this file + +### LibContinual Configuration Settings + +#### Data Settings + +- `data_root`: The storage path of the dataset +- `image_size`: The size of the input image +- `pin_memory`: Whether to use memory to speed up reading +- `workers`: The number of processes for parallel data reading + +```yaml +data_root: /data/cifar10/ +image_size: 32 +``` + +#### Model Settings + +`backbone`: Backbone network information used in this method + +- `name`: The name of the backbone network, which needs to correspond with the implementation in the LibContinual framework + +- `kwargs`: Parameters required by the backbone network, which need to be consistent with the naming in the code + + - `num_classes`: The total number of classes needed to be classified by the model + - `args`: Other required parameters + - `dataset`: The dataset being used, as different datasets have different backbone network implementation details + + ```yaml + backbone: + name: resnet18 + kwargs: + num_classes: 10 + args: + dataset: cifar10 + ``` + +`classifier`: Classifier information used in the method + +- `name`: The name of the classifier, which needs to be consistent with the method implementation in LibContinual + +- `kwargs`: Initialization parameters of the classifier, which need to be consistent with the names in the code implementation + + ```yaml + classifier: + name: PASS + kwargs: + num_class: 100 + feat_dim: 512 + # The following are method-related hyperparameters + feat_KD: 10.0 + proto_aug: 10.0 + temp: 0.1 + ``` + +#### Training Settings + +- `init_cls_num`: The number of training classes for the first task +- `inc_cls_num`: The number of training classes for subsequent incremental tasks +- `task_num`: The total number of tasks +- `init_epoch`: The number of training epochs for the first task +- `epoch`: The number of training epochs for incremental tasks +- `val_per_epoch`: How many epochs to test performance on the test set +- `batch_size`: Batch size during training +- `warm_up`: The number of warm-up epochs before training + +```yaml +warmup: 0 +init_cls_num: 50 +inc_cls_num: 10 +task_num: 6 +batch_size: 64 +init_epoch: 100 +epoch: 100 +val_per_epoch: 10 +``` + +#### Optimizer Settings + +- `optimizer`: Information about the optimizer used in training + - `name`: The name of the optimizer, only supports optimizers built into `Pytorch` + - `kwargs`: Parameters used by this optimizer, parameter names need to match the parameter names in Pytorch optimizers, for example + - `lr`: Learning rate of the optimizer + - `weight_decay`: Weight decay + +```yaml +optimizer: + name: Adam + kwargs: + lr: 0.001 + weight_decay: 0.0002 +``` + +`lr_scheduler`: Learning rate adjustment strategy used in training, only supports adjustment strategies built into `Pytorch` + +- `name`: The name of the learning rate adjustment strategy +- `kwargs`: Parameters of the learning rate adjustment strategy, note that different learning rate adjustment strategies will have different parameters + +```yaml +lr_scheduler: + name: StepLR + kwargs: + step_size: 45 + gamma: 0.1 +``` + +#### Hardware Settings + +- `device_ids`: GPU IDs used +- `n_gpu`: The number of GPUs used in parallel during training, if it is `1`, it means parallel training is not used +- `deterministic`: Whether to enable `torch.backend.cudnn.benchmark` + +```yaml +device_ids: 3 +n_gpu: 1 +seed: 0 +deterministic: False +``` + diff --git a/docs/tutorials/en/data_module_en.md b/docs/tutorials/en/data_module_en.md new file mode 100644 index 0000000000000000000000000000000000000000..5f9405b93111b7fc2b3a09102fe9040570c2cb16 --- /dev/null +++ b/docs/tutorials/en/data_module_en.md @@ -0,0 +1,45 @@ +# Data Module + +## Related codes: + +``` +core/data/augments.py +core/data/dataloader.py +core/data/dataset.py +``` + +## Dataset file format + +In `LibContinual`, the dataset used has a fixed format. We read the data according to the dataset format set by most continual learning settings, such as [CIFAR-10](https://pytorch.org/vision/stable/datasets.html) and [CIFAR-100](https://pytorch.org/vision/stable/datasets.html). So we only need to download the dataset from the network and decompress it to use. If you want to use a new dataset and its data format is different from the above datasets, you need to convert it to the same format yourself. + +Like CIFAR-10, the file format of the dataset should be the same as the following example: + +``` +dataset_folder/ +├── train/ +│   ├── class_1/ +│      ├── image_1.png +│ ├── ... +│      └── image_5000.png +│ ├── ... +│   ├── class_10/ +│      ├── image_1.png +│ ├── ... +│      └── image_5000.png +├── test/ +│   ├── class_1/ +│      ├── image_1.png +│ ├── ... +│      └── image_5000.png +│ ├── ... +│   ├── class_10/ +│      ├── image_1.png +│ ├── ... +│      └── image_5000.png +``` + +The training images and test images need to be placed in the `train` and `test` folders respectively, where all images of the same category are placed in folde with the same name as the category, such as `cat` , `dog`, etc. + +## Configure Datasets + +After downloading or organizing the dataset according to the above file format, simply modify the `data_root` field in the configuration file. Note that `LibeContinual` will print the dataset folder name as the dataset name on the log. diff --git a/docs/tutorials/en/process_en.md b/docs/tutorials/en/process_en.md new file mode 100644 index 0000000000000000000000000000000000000000..7a1a343444b2cdb088e701f62c6c69123e24a56e --- /dev/null +++ b/docs/tutorials/en/process_en.md @@ -0,0 +1,139 @@ +# This section introduces the flow control of the code. + +The flow control process involves the following files: +- `run_trainer.py`: The outermost entry point of the program. +- `trainer.py`: The implementation file of the Trainer class, used to implement the training process of the model. +- `model.py`: The model file located in the `./core/model` folder, used to implement specific algorithm models. + +## Entry Point + +At the very beginning, the outermost logic execution of the code is in `run_trainer.py`. In this file, we initialize the trainer module and call its `train_loop` method to start the entire training process of the algorithm. + +```python +# Initialization and calling of Trainer in run_trainer.py +trainer = Trainer(rank, config) +trainer.train_loop() +``` + +As follow, we will introduce [Initialization](#Initialization), [Loop control](#loop-control), [Task preprocessing](#task-preprocessing), [Model training](#model-training), [Post-task processing](#task-post-processing) and [Evaluation](#evaluation-process). + + + +## Initialization + +After the above initialization, we will get a `trainer` class. By calling the relevant methods of this class, we proceed with the subsequent model training. + +```python +class Trainer(object): + """ + The Trainer. + Build a trainer from config dict, set up optimizer, model, etc. + """ + def __init__(self, rank, config): + # initialize the Trainer + pass +``` +During the initialization process of the trainer, we mainly initialize parameters such as the number of tasks, training rounds, training devices, log files, and result storage containers. For methods that require replay, we also initialize a buffer size. For methods that do not require replay, we initialize it to 0. In addition to initializing these necessary parameters, we also initialize the partitioning of the training and testing sets through the init_dataloader method. The meanings of the variables involved in this process are as follows: +- `config`: Save model related configuration parameters +- `Logger`: Storage of model logs +- `device`: specifies the device for model training +- ` init_data `: Set relevant data partitioning +- `model`: Save the model +- `buffer`: Possible memory replay +- `*meter`: Save relevant evaluation data + +After the above initialization, we will obtain a `trainer` class, which can be used for subsequent model training by calling its related methods. + +## Loop Control + +After completing initialization, start the training process of the model by calling the `train_loop` method of `trainer`: +```python +class Trainer(object): + def train_loop(self,): + """ + The norm train loop: before_task, train, test, after_task + """ + pass +``` +In this process, the first step is to call the model's [Task Preprocessing](#Task-Preprocessing), followed by [model training](#model-training). After the model training is completed, the model's [post task processing](#task-post-processing) is also called, and finally, [model evaluation](#evaluation-process) is performed. The following will further describe these processes. + +## Task Preprocessing + +In the task preprocessing process, the model will undergo some processing that may not be strongly related to model parameter optimization. For example, dynamically expanding related methods can initialize the network parameters that need to be expanded before the task. The specific implementation needs to be realized in the `before_task` method of each model file in the `model` module: + +```python +# An example from the `./core/model/replay/finetune.py` file is shown below +class Finetune(nn.Module): + def before_task(self, task_idx, buffer, train_loader, test_loaders): + pass +``` + +## Model Training + +Model training optimization is implemented through the `observe` method: + +```python +class Trainer(object): + def _train(self, epoch_idx, dataloader): + ... + output, acc, loss = self.model.observe(batch) + ... +``` +The method takes a batch of data and returns the logits, training accuracy, and training loss of the model's output. The model parameters are optimized by backpropagating through this loss. The specific implementation can refer to the content in `./core/model/replay/finetune.py`: +```python +# An example from the `./core/model/replay/finetune.py` file is shown below +class Finetune(nn.Module): + def observe(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logit = self.classifier(self.backbone(x)['features']) + loss = self.loss_fn(logit, y) + pred = torch.argmax(logit, dim=1) + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0), loss +``` + +## Task Post-processing + +Similar to task preprocessing, task post-processing is used for some operations that may not be strongly related to model parameter optimization. For example, the method of re-issuing can update the replay memory in the post-task processing. The specific implementation needs to be realized in the `after_task` method of each model file in the `model` module: + +```python +# An example from the `./core/model/replay/finetune.py` file is shown below +class Finetune(nn.Module): + def after_task(self, task_idx, buffer, train_loader, test_loaders): + pass +``` + +In addition, apart from some special operations, most operations that are not strongly related to model optimization can be processed either before or after the task, with the same effect. + +## Evaluation Process + +During the training process, the model's loss, training accuracy, and other metrics are saved in the `train_meter` for analysis: + +```python +class Trainer(object): + def train_loop(self,): + ... + train_meter = self._train(epoch_idx, dataloader) + ... +``` + +In the evaluation phase of the model, the model is frozen and evaluated on the test set, and the results are saved in the `test_meter`. This is specifically implemented through the `_validate` method: + +```python +class Trainer(object): + def _validate(self, task_idx): + dataloaders = self.test_loader.get_loader(task_idx) + self.model.eval() + meter = self.test_meter + per_task_acc = [] + with torch.no_grad(): + for t, dataloader in enumerate(dataloaders): + meter[t].reset() + for batch_idx, batch in enumerate(dataloader): + output, acc = self.model.inference(batch) + meter[t].update("acc1", acc) + per_task_acc.append(round(meter[t].avg("acc1"), 2)) + return {"avg_acc": np.mean(per_task_acc), "per_task_acc": per_task_acc} +``` \ No newline at end of file diff --git a/docs/tutorials/en/write_a_config_yaml_en.md b/docs/tutorials/en/write_a_config_yaml_en.md new file mode 100644 index 0000000000000000000000000000000000000000..b1e518e69e79b446d3d8cf8164a78f6c80dd166a --- /dev/null +++ b/docs/tutorials/en/write_a_config_yaml_en.md @@ -0,0 +1,154 @@ +# Write a `.yaml` configuration file + +Code for this section: +``` +core/config/config.py +config/* +``` + +## Composition of the configuration file in LibContinual + +The configuration file of LibContinual uses a yaml format file and it also supports reading the global configuration changes from the command line. We have pre-defined a default configuration `core/config/default.yaml`. The users can put the custom configuration into the `config/` directory, and save this file in the `yaml` format. At parsing, the sequencing relationship of defining the configuration of the method is `default.yaml->config/->console`. The latter definition overrides the same value in the former definition. + +Although most of the basic configurations have been set in the `default.yaml`, you can not directly run a program just using the `default.yaml`. Before running the code, the users are required to define a configuration file of one method that has been implemented in LibContinual in the `config/` directory. + +Considering that CL menthods usually have some basic parameters, such as `image_size` or `device id`, which are often needed to be changed, LibContinual also supports making changes to some simple configurations on the command line without modifying the `yaml` file. Similarly, during training and test, because many parameters are the same of different methods, we wrap these same parameters together and put them into the`config/headers` for brevity. In this way, we can write the `yaml` files of the custom methods succinctly by importing them. + +The following is the composition of the files in the `config/headers` directory. + +- `data.yaml`: The relevant configuration of the data is defined in this file. +- `device.yaml`: The relevant configuration of GPU is defined in this file. +- `model.yaml`: The relevant configuration of the model is defined in this file. +- `optimizer.yaml`: The relevant configuration of the optimizer used for training is defined in this file. + +## The settings of the configuration file in LibContinual + +The following details each part of the configuration file and explain how to write them. An example of how the bic method is configured is also presented. + +### The settings for data + ++ `data_root`: The storage path of the dataset. + ++ `image_size`: The size of the input image. + ++ `pin_memory`: Whether to use memory acceleration for reading. + ++ `augment`: Whether to use data augmentation. + ++ `init_cls_num`: Initial number of classes. + ++ `inc_cls_num`: Incremental number of classes. + ++ `task_num`: Number of tasks. + ++ `works`: Number of working threads for data loading and preprocessing. + + ```yaml + data_root: /data/cifar100 + image_size: 84 + pin_memory: False + augment: True + init_cls_num: 20 + inc_cls_num: 20 + task_num: 5 + works: 8 + ``` + +### The settings for model + ++ `backbone`: The `backbone` information used in the method. + + `name`: The name of the `backbone`, needs to match the case of the `backbone` implemented in LibContinual. + + `kwargs`: The parameters used in the `backbone`, must keep the name consistent with the name in the code. + + `num_classes`: Number of classes. + + `args`: Other parameters, for example, the dataset used. + +```yaml + backbone: + name: resnet18 + kwargs: + num_classes: 100 + args: + dataset: cifar100 +``` + ++ `classifier`: The `classifier` information used in the method. + + `name`: The name of the `classifier`, needs to match the case of the `classifier` implemented in LibContinual. + ++ `kwargs`: The parameters used in the `classifier` initialization, must keep the name consistent with the name in the code. + + + `feat_dim`: Dimension settings + + ```yaml + classifier: + name: bic + kwargs: + feat_dim: 512 + ``` + +### The settings for training + ++ `epoch`: The number of `epoch` during training. + ++ `test_epoch`: The number of `epoch` during testing. + ++ `val_per_epoch`: The number of `epoch` in each verification phase. + ++ `stage2_epoch`: The number of `epoch` for strategy 2. + ++ `batch_size`: The batch size for training. + + ```yaml + epoch: 50 + test_epoch: 5 + val_per_epoch: 5 + stage2_epoch: 100 + batch_size: 128 + ``` + +### The settings for optimizer + ++ `optimizer`: Optimizer information used during training. + + + `name`: The name of the Optimizer, only temporarily supports all Optimizers provided by `PyTorch`. + + `kwargs`: The parameters used in the optimizer, and the name needs to be the same as the parameter name required by the `PyTorch` optimizer. + + `other`: Currently, the framework only supports the learning rate used by each part of a separately specified method, and the name needs to be the same as the variable name used in the method. + + ```yaml + optimizer: + name: SGD + kwargs: + lr: 0.01 + weight_decay: 2e-4 + momentum: 0.9 + ``` + ++ `lr_scheduler`: The learning rate adjustment strategy used during training, only temporarily supports all the learning rate adjustment strategies provided by `PyTorch`. + + `name`: The name of the learning rate adjustment strategy. + + `kwargs`: Other parameters used in the learning rate adjustment strategy in `PyTorch`. + + ```yaml + lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [25, 50] + ``` + +### The settings for Hardware + ++ `device_ids`: The `gpu` number, which is the same as the `nvidia-smi` command. + ++ `n_gpu`: The number of parallel `gpu` used during training, if `1`, it can't apply to parallel training. + ++ `seed`: Seed points used in `numpy`,`torch`,and `cuda`. + ++ `deterministic`: Whether to turn on `torch.backend.cudnn.benchmark` and `torch.backend.cudnn.deterministic` and whether to determine random seeds during training. + + ```yaml + device_ids: 0,1,2,3,4,5,6,7 + n_gpu: 4 + seed: 1993 + deterministic: False + ``` + + diff --git a/docs/tutorials/install.md b/docs/tutorials/install.md new file mode 100644 index 0000000000000000000000000000000000000000..7caad7d9a9eec89d10243da278acd11083293189 --- /dev/null +++ b/docs/tutorials/install.md @@ -0,0 +1,59 @@ +## Installation + +This section provides a tutorial on building a working environment for `LibContinual` from scratch. + +## Get the `LibContinual` library + +Use the following command to get `LibContinual`: + +```shell +cd ~ +git clone https://github.com/RL-VIG/LibContinual.git +``` + +## Configure the `LibContinual` environment + +The environment can be configured in any of the following ways: + +1. conda(recommend) + ```shell + cd # cd in `LibContinual` directory + conda env create -f requirements.yaml + ``` + +2. pip + ```shell + cd # cd in `LibContinual` directory + pip install -r requirements.txt + ``` +3. or whatever works for you as long as the following package version conditions are meet: + ``` + diffdist==0.1 + numpy==1.21.5 + pandas==1.1.5 + Pillow==9.2.0 + PyYAML==6.0.1 + scikit_learn==1.0.2 + torch==1.12.1 + torchvision==0.13.1 + tqdm==4.64.1 + python==3.8.0 + timm=0.6.7 + ``` + +## Test the installation + + +1. set the `config` as follows in `run_trainer.py`: + ```python + config = Config("./config/lucir.yaml").get_config_dict() + ``` +2. modify `data_root` in `config/lucir.yaml` to the path of the dataset to be used. +3. run code + ```shell + python run_trainer.py + ``` +4. If the first output is correct, it means that `LibContinual` has been successfully installed. + +## Next + diff --git a/docs/tutorials/zh/add_a_new_method.md b/docs/tutorials/zh/add_a_new_method.md new file mode 100644 index 0000000000000000000000000000000000000000..8154044992ce44d367c06b0e8e1a8ca21adb111d --- /dev/null +++ b/docs/tutorials/zh/add_a_new_method.md @@ -0,0 +1,198 @@ +# Add a new method + +下面以[`LUCIR`](https://openaccess.thecvf.com/content_CVPR_2019/html/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.html)方法为例,描述如何添加一种新的方 +法。
+ +首先,所有方法都继承同一父类`Finetune`。 + +```python +class Finetune(nn.Module): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + ... + self.kwargs = kwargs + + def observe(self, data): + ... + return pred, acc / x.size(0), loss + + def inference(self, data): + ... + return pred, acc / x.size(0) + + def forward(self, x): + ... + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + pass + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + pass + + def get_parameters(self, config): + ... + return train_parameters +``` +`Finetune`类包含了一个方法需要具备的几个重要接口: ++ `__init__`:初始化函数,用于初始化各方法需要的参数。 ++ `observe`:用于训练阶段调用,输入一个batch的训练样本,返回预测、准确率以及前向损失。 ++ `inference`:用于推理阶段调用,输入一个batch的样本,返回分类输出、准确率。 ++ `forward`:重写`pytorch`的`Module`中的`forward`函数,返回`backbone`的输出。 ++ `before_task`:在每个任务开始训练前调用,用于对模型结构、训练参数等进行调整,需要用户自定义。 ++ `after_task`:在每个任务开始训练后调用,用于对模型结构、训练参数等进行调整,需要用户自定义。 ++ `get_parameters`:在每个任务开始训练前调用,返回当前任务的训练参数。 + + +## LUCIR + +### 建立模型 +首先在`core/model/replay`下添加`lucir.py`文件:(此处省略部分源码) +```python +class LUCIR(Finetune): + def __init__(self, backbone, feat_dim, num_class, **kwargs): + super().__init__(backbone, feat_dim, num_class, **kwargs) + self.kwargs = kwargs + self.K = kwargs['K'] + self.lw_mr = kwargs['lw_mr'] + self.ref_model = None + + + def before_task(self, task_idx, buffer, train_loader, test_loaders): + self.task_idx = task_idx + + self.ref_model = copy.deepcopy(self.backbone) + ... + new_fc = SplitCosineLinear(in_features, out_features, self.kwargs['inc_cls_num']) + + self.loss_fn1 = nn.CosineEmbeddingLoss() + self.loss_fn2 = nn.CrossEntropyLoss() + self.loss_fn3 = nn.MarginRankingLoss(margin=self.kwargs['dist']) + ... + + self.backbone = self.backbone.to(self.device) + if self.ref_model is not None: + self.ref_model = self.ref_model.to(self.device) + + + def _init_new_fc(self, task_idx, buffer, train_loader): + if task_idx == 0: + return + ... + self.backbone.fc.fc2.weight.data = novel_embedding.to(self.device) + + def _compute_feature(self, feature_model, loader, num_samples, num_features): + ... + + + def observe(self, data): + x, y = data['image'], data['label'] + logit = self.backbone(x) + + ... + ref_outputs = self.ref_model(x) + loss = self.loss_fn1(...) * self.cur_lamda + loss += self.loss_fn2(...) + if hard_num > 0: + ... + loss += self.loss_fn3(...) * self.lw_mr + + pred = torch.argmax(logit, dim=1) + + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0), loss + + def after_task(self, task_idx, buffer, train_loader, test_loaders): + if self.task_idx > 0: + self.handle_ref_features.remove() + ... + + + def inference(self, data): + pass + + + def _init_optim(self, config, task_idx): + ... + tg_params =[{'params': base_params, 'lr': 0.1, 'weight_decay': 5e-4}, \ + {'params': self.backbone.fc.fc1.parameters(), 'lr': 0, 'weight_decay': 0}] + return tg_params +``` ++ 在`__init__`中,对`LUCIR`所需要的参数`K, lw_mr, ref_model`进行初始化。 ++ 在`before_task`中,根据`LUCIR`的需要,我们在任务开始前对新旧分类头进行更新,并根据`task_idx`设置不同的损失函数 。 ++ 在`observe`中,我们实现了训练阶段中`LUCIR`的训练算法,根据`task_idx`采用不同的训练方法对模型进行训练。 ++ 在`after_task`中,根据`LUCIR`算法需要移除一些`hook`操作。 ++ 在`_init_optim`中,我们完成了对于训练参数的选择。 + +以上几个接口的实现是`LUCIR`算法与其他算法的不同点,其他接口无特殊处理可以不实现交由`Finetune`实现
+注意,由于持续学习算法对于第一个任务和其他任务有不同的操作,在`before_task`会传入`task_idx`来标识当前是第几个任务。
+ + + + +## 新增lucir.yaml文件 +各参数含义请参考['config.md'](./config_file_zh.md) +### 数据划分相关参数 +```yaml +data_root: /data/fanzhichen/continual/cifar100 +image_size: 32 +save_path: ./ +init_cls_num: 50 +inc_cls_num: 10 +task_num: 6 +``` + +### 训练优化器相关参数 +```yaml +optimizer: + name: SGD + kwargs: + lr: 0.1 + momentum: 0.9 + weight_decay: 0.0005 + +lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [80, 120] +``` + +### backbone相关参数 +```yaml +backbone: + name: resnet32 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + cosine_fc: True +``` + +### buffer相关参数 +`name`: 选择`LinearBuffer`, 会将数据在任务开始前与当前任务数据合并在一起。
+`strategy`:选择`herding`更新策略,目前可支持`random`,`equal_random`,`reservoir`,`herding`,`None`
+```yaml +buffer: + name: LinearBuffer + kwargs: + buffer_size: 2000 + batch_size: 128 + strategy: herding # random, equal_random, reservoir, herding +``` + + +### 算法相关参数 +`name`:此处标识所采用何种算法 +```yaml +classifier: + name: LUCIR + kwargs: + num_class: 100 + feat_dim: 512 + init_cls_num: 50 + inc_cls_num: 10 + dist: 0.5 + lamda: 5 + K: 2 + lw_mr: 1 + +``` diff --git a/docs/tutorials/zh/config_file_zh.md b/docs/tutorials/zh/config_file_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..292054216cc1f6ad10824a4f4bb5dcd6d9c4ff79 --- /dev/null +++ b/docs/tutorials/zh/config_file_zh.md @@ -0,0 +1,138 @@ + + +配置文件的路径如下: + +```` +config/* +```` + +### LibContinual配置文件构成 + +LibContinual配置文件使用`yaml`文件格式。我们预定义的配置文件位于`core/config/default.yaml`,用户可以将自定义的配置项放入`config/`目录下,并且保存为`.yaml`格式。 + +虽然大多数配置已经在`default.yaml`提前编写好了,但是您不能直接使用`default.yaml`配置来运行框架,需要预先定义所运行方法对应的配置文件。可以参考下面的参数说明编写你自己的配置文件。 + +在`config/headers`文件夹中,包含了以下文件: + +- `data.yaml`:数据相关的配置定义在此文件中 +- `device.yaml`:与GPU相关的配置项定义在此文件中 +- `model.yaml`:与模型相关的配置定义此文件中 +- `optimizer.yaml`:与优化器相关的配置定义在此文件中 + +### LibContinual配置文件的设置 + +#### 数据设置 + +- `data_root`:数据集的存储路径 +- `image_size`:输入图片的大小 +- `pin_momery`:是否使用内存来加速读取 +- `workers`:并行读取数据进程的数量 + +```yaml +data_root: /data/cifar10/ +image_size: 32 +``` + +#### 模型设置 + +`backbone`:该方法中使用的骨干网络信息 + +- `name`: 骨干网络的名称,需要与LibContinual框架中的实现所对应 + +- `kwargs`:骨干网络所需要的参数,需要与代码中的命名一致 + + - `num_classes`:模型需要的分类总数 + - `args`:需要的其他参数 + - `dataset`:所使用的数据集,不同数据集的骨干网络实现细节有所不同 + + ```yaml + backbone: + name: resnet18 + kwargs: + num_classes: 10 + args: + dataset: cifar10 + ``` + +`classifier`:方法中使用的分类器信息 + +- `name`:分类器的名称,需要与LibContinual中的方法实现保持一致 + +- `kwargs`:分类器的初始化参数,需要与代码实现的名称保持一致 + + ```yaml + classifier: + name: PASS + kwargs: + num_class: 100 + feat_dim: 512 + # 下面是方法相关的超参数 + feat_KD: 10.0 + proto_aug: 10.0 + temp : 0.1 + ``` + +#### 训练设置 + +- `init_cls_num`:第一个任务的训练类别数 +- `inc_cls_num`:随后增量任务的训练类别数 +- `task_num`:任务总数 +- `init_epoch`:第一个任务上的训练轮数 +- `epoch`:增量任务上的训练轮数 +- `val_per_epoch`:每过多少轮训练在测试集上测试性能 +- `batch_size`:训练时的批次大小 +- `warm_up`:训练之前的预热轮次 + +```yaml +warmup: 0 +init_cls_num: 50 +inc_cls_num: 10 +task_num: 6 +batch_size: 64 +init_epoch: 100 +epoch: 100 +val_per_epoch: 10 +``` + +#### 优化器设置 + +- `optimizer`:训练中使用的优化器信息 + - `name`:优化器的名称,只支持`Pytorch`内置的优化器 + - `kwargs`:该优化器使用的参数,参数名称需要与Pytorch中优化器参数的参数名称相同,例如 + - `lr`:优化器学习率 + - `weight_decay`:权重衰减 + +```yaml +optimizer: + name: Adam + kwargs: + lr: 0.001 + weight_decay: 0.0002 +``` + +`lr_scheduler`:训练中使用的学习率调整策略,只支持`Pytorch`内置的优化器调整策略 + +- `name`:学习率调整策略的名称 +- `kwargs`:学习率调整策略的参数,注意不同的学习率调整策略会有不同的参数 + +```yaml +lr_scheduler: + name: StepLR + kwargs: + step_size: 45 + gamma: 0.1 +``` + +#### 硬件设置 + +- `device_ids`:所使用的GPU编号 +- `n_gpu`:训练中使用的并行GPU数量, 如果是`1`, 表示不使用并行训练 +- `deterministic`:是否开启 `torch.backend.cudnn.benchmark` 和 `torch.backend.cudnn.deterministic` +- `seed`:在 `numpy`,`torch`和 `cuda`中使用的随机种子 + +```yaml +device_ids: 3 +n_gpu: 1 +seed: 0 +deterministic: False +``` diff --git a/docs/tutorials/zh/data_module_zh.md b/docs/tutorials/zh/data_module_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..0e3c5d6b72c566a82eca1bf17b9f8c34a6a96165 --- /dev/null +++ b/docs/tutorials/zh/data_module_zh.md @@ -0,0 +1,45 @@ +# 数据模块 + +## 本节相关代码: + +``` +core/data/augments.py +core/data/dataloader.py +core/data/dataset.py +``` + +## 数据集格式 + +在`LibContinual`中,所用的数据集有固定的格式。我们按照大多数持续学习设置下的数据集格式进行数据的读取,例如 [CIFAR-10](https://pytorch.org/vision/stable/datasets.html) 和 [CIFAR-100](https://pytorch.org/vision/stable/datasets.html) ,因此只需从网络上下载数据集并解压就可以使用。如果你想要使用一个新的数据集,并且该数据集的数据形式与以上数据集不同,那么你需要自己动手将其转换成相同的格式。 + +与 CIFAR-10 一样,数据集的格式应该和下面的示例一样: + +``` +dataset_folder/ +├── train/ +│   ├── class_1/ +│      ├── image_1.png +│ ├── ... +│      └── image_5000.png +│ ├── ... +│   ├── class_10/ +│      ├── image_1.png +│ ├── ... +│      └── image_5000.png +├── test/ +│   ├── class_1/ +│      ├── image_1.png +│ ├── ... +│      └── image_5000.png +│ ├── ... +│   ├── class_10/ +│      ├── image_1.png +│ ├── ... +│      └── image_5000.png +``` + +训练图像、测试图像需要分别放置在`train`和`test`文件夹下,其中同一类别所有图像放置在与类别同名文件夹中,例如`cat`、`dog`等。 + +## 配置数据集 + +当下载好或按照上述格式整理好数据集后,只需要在配置文件中修改`data_root`字段即可,注意`LibeContinual`会将数据集文件夹名当作数据集名称打印在log上。 diff --git a/docs/tutorials/zh/process_zh.md b/docs/tutorials/zh/process_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..67e788959fb49bc03f3fd52a97b9c960a493d39a --- /dev/null +++ b/docs/tutorials/zh/process_zh.md @@ -0,0 +1,119 @@ +# 本节介绍代码的流程控制 + +流程控制过程涉及以下几个文件: +- `run_trainer.py`: 程序最外层的入口 +- `trainer.py`: `Trainer`类别的实现文件,用来实现模型的训练流程 +- `model.py`:位于`./core/model`文件夹下的模型文件,用于实现具体算法模型 + +## 入口 +首先,代码执行逻辑的最外层是`run_trainer.py`,在这个文件中,我们通过初始化`trainer`模块后,调用它的`train_loop`方法开启真个算法的训练流程。 +```python +# run_trainer.py中初始化和调用Trainer +trainer = Trainer(rank, config) +trainer.train_loop() +``` +以下我们将从[初始化](#初始化)、[循环控制](#循环控制)、[任务前处理](#任务前处理)、[模型训练](#模型训练)、[任务后处理](#任务后处理)、[评估流程](#评估流程)几个方面展开说明。 + +## 初始化 +首先,需要对训练器进行初始化,初始化的实现代码位于`trainer.py`文件中: +```python +class Trainer(object): + """ + The Trainer. + Build a trainer from config dict, set up optimizer, model, etc. + """ + def __init__(self, rank, config): + # initialize the Trainer + pass +``` +在训练器初始化的过程中,我们主要是初始化任务数量、训练轮次、训练设备、日志文件、结果存储容器等参数,需要重放的方法还要初始化一个buffer大小,对于不需要重放的方法就初始化为0。除了初始化这些必备的参数之外,还通过_init_dataloader方法初始化训练集和测试集的划分。这一过程中涉及到的变量含义如下: +- `config`: 保存模型相关的配置参数 +- `logger`: 模型的日志存储 +- `device`: 指定模型训练的设备 +- `_init_data`: 设置相关的数据划分 +- `model`: 保存模型 +- `buffer`: 可能存在重放内存 +- `*meter`: 保存相关的评估数据 + +经过以上的初始化,会得到一个`trainer`类,通过调用这个类的相关方法进行后面的模型训练。 + +## 循环控制 +在完成初始化之后,通过调用`trainer`的`train_loop`方法开始模型的训练流程: +```python +class Trainer(object): + def train_loop(self,): + """ + The norm train loop: before_task, train, test, after_task + """ + pass +``` +在这个过程中,首先对调用模型的[任务前处理](#任务前处理),而后进行[模型训练](#模型训练),在模型训练结束后还要调用模型的[任务后处理](#任务后处理),最后进行[模型的评估](#评估流程)。下面将对这些过程进行进一步描述。 + +## 任务前处理 +在任务前处理过程中,模型将进行一个和模型参数优化可能并没有强相关的一些处理。比如,动态拓展相关的方法,可以在任务前初始化需要拓展的网络参数。具体的实现,需要在model模块中,每个模型文件各自的before_task方法下实现: +```python +# 以./core/model/replay/finetune.py文件为例展示 +class Finetune(nn.Module): + def before_task(self, task_idx, buffer, train_loader, test_loaders): + pass +``` + +## 模型训练 +模型训练优化通过observe方法实现: +```python +class Trainer(object): + def _train(self, epoch_idx, dataloader): + ... + output, acc, loss = self.model.observe(batch) + ... +``` +该方法输入一个batch的数据会返回模型输出的logits、训练精度、训练损失,通过对这个损失进行反向回传来优化模型参数。方法的具体实现可以参考`./core/model/replay/finetune.py`中的内容: +```python +# 以./core/model/replay/finetune.py文件为例展示 +class Finetune(nn.Module): + def observe(self, data): + x, y = data['image'], data['label'] + x = x.to(self.device) + y = y.to(self.device) + logit = self.classifier(self.backbone(x)['features']) + loss = self.loss_fn(logit, y) + pred = torch.argmax(logit, dim=1) + acc = torch.sum(pred == y).item() + return pred, acc / x.size(0), loss +``` +## 任务后处理 +和任务前处理相似,任务后处理用来进行一些和模型参数优化可能并不强相关的一些操作。比如,重放的方法可以在任务后处理中更新重放的内存。具体实现在每个模型文件的after_task方法中实现。具体的实现,需要在model模块中,每个模型文件各自的after_task方法下实现: +```python +# 以./core/model/replay/finetune.py文件为例展示 +class Finetune(nn.Module): + def after_task(self, task_idx, buffer, train_loader, test_loaders): + pass +``` +此外,除了一些较为特别的操作,大部分与模型优化不强相关的操作既可以放在任务前也可以放在任务后进行处理,效果是一样的。 + +## 评估流程 +在训练过程中,模型的损失、训练精度等指标会被保存到`train_meter`中,用于分析: +```python +class Trainer(object): + def train_loop(self,): + ... + train_meter = self._train(epoch_idx, dataloader) + ... +``` +在模型的评估阶段,将模型冻结后在测试集上评估,并将结果保存到`test_meter`中,具体通过`_validate`方法实现: +```python +class Trainer(object): + def _validate(self, task_idx): + dataloaders = self.test_loader.get_loader(task_idx) + self.model.eval() + meter = self.test_meter + per_task_acc = [] + with torch.no_grad(): + for t, dataloader in enumerate(dataloaders): + meter[t].reset() + for batch_idx, batch in enumerate(dataloader): + output, acc = self.model.inference(batch) + meter[t].update("acc1", acc) + per_task_acc.append(round(meter[t].avg("acc1"), 2)) + return {"avg_acc" : np.mean(per_task_acc), "per_task_acc" : per_task_acc} +``` \ No newline at end of file diff --git a/docs/tutorials/zh/write_a_config_yaml_zh.md b/docs/tutorials/zh/write_a_config_yaml_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..071574ef3c7b2d9d53b938402407d2b21a9e7cae --- /dev/null +++ b/docs/tutorials/zh/write_a_config_yaml_zh.md @@ -0,0 +1,152 @@ +# 编写`.yaml`配置文件 + +本节相关代码: +``` +core/config/config.py +config/* +``` + +## LibContinual中配置文件的组成 + +LibContinual的配置文件采用了yaml格式的文件,同时也支持从命令行中读取一些全局配置的更改。我们预先定义了一个默认的配置`core/config/default.yaml`。用户可以将自定义的配置放在`config/`目录下,保存为`yaml`格式的文件。配置定义在解析时的优先级顺序是`default.yaml->config/->console`。后一个定义会覆盖前一个定义中名称相同的值。 + +尽管`default.yaml`中设置的是持续学习中的一些最基础的配置,无法仅依靠`default.yaml`直接运行程序。运行代码前,用户需要在`config/`目录下定义已经在LibContinual中实现了的方法的配置。 + +考虑到持续方法有一些基本参数例如`image_sie, epoch`或者`device id`,这样的参数是经常需要改动的。LibContinual支持在命令行中对一些简单的配置进行更改而不需要修改`yaml`文件。同样的,在训练和测试过程中,很多不同的持续学习方法的参数是相同的。为了简洁起见,我们将这些相同的参数包装到了一起,放到了`config/headers`目录下,这样就能够通过导入的方式简洁地编写自定义方法的`yaml`文件。 + +以下是`config/headers`目录下文件的构成。 + +- `data.yaml`:定义了训练所使用的数据的相关配置。 +- `device.yaml`:定义了训练所使用的GPU的相关配置。 +- `model.yaml`:定义了模型训练的相关配置。 +- `optimizer.yaml`:定义了训练所使用的优化器的相关配置。 + +## LibContinual中配置文件的设置 + +以下详细介绍配置文件中每部分代表的信息以及如何编写,以下将以bic方法的配置给出示例。 + +### 数据设置 + ++ `data_root`:数据集存放的路径。 + ++ `image_size`:输入图像的尺寸。 + ++ `pin_momery`:是否使用内存加速读取。 + ++ `augment`:是否使用数据增强。 + ++ `init_cls_num`: 初始类别数量。 + ++ `inc_cls_num`: 增量类别数量。 + ++ `task_num`: 任务数量。 + ++ `works`:数据加载和预处理的工作线程数量。 + + ```yaml + data_root: /data/cifar100 + image_size: 84 + pin_memory: False + augment: True + init_cls_num: 20 + inc_cls_num: 20 + task_num: 5 + works: 8 + ``` + + +### 模型设置 + ++ `backbone`:方法所使用的`backbone`信息。 + + + `name`:使用的backbone的名称,需要与LibContinual中实现的backbone的大小写一致。 + + `kwargs`:`backbone`初始化时用到的参数,必须保持名称与代码中的名称一致。 + + `num_classes`:类别数量。 + + `args`:其他项参数,例如所使用的数据集`dataset`。 + + ```yaml + backbone: + name: resnet18 + kwargs: + num_classes: 100 + args: + dataset: cifar100 + ``` ++ `classifier`:方法所使用的方法信息。 + + + `name`:使用的方法的名称,需要与LibContinual中实现的方法的名称一致。 ++ `kwargs`:方法初始化时用到的参数,必须保持名称与代码中的名称一致。 + + + `feat_dim`:维度设定。 + + ```yaml + classifier: + name: bic + kwargs: + feat_dim: 512 + ``` + +### 训练设置 + ++ `epoch`:训练的`epoch`数。 + ++ `test_epoch`: 测试的`epoch`数。 + ++ `val_per_epoch`: 验证阶段的每一次的`epoch`数。 + ++ `stage2_epoch`: 策略2的`epoch`数。 + ++ `batch_size`: 训练的批次尺寸。 + + ```yaml + epoch: 50 + test_epoch: 5 + val_per_epoch: 5 + stage2_epoch: 100 + batch_size: 128 + ``` + +### 优化器设置 + ++ `optimizer`:训练阶段使用的优化器信息。 + + `name`:优化器名称,当前仅支持`Pytorch`提供的所有优化器。 + + `kwargs`:传入优化器的参数,名称需要与Pytorch优化器所需要的参数名称相同。 + + `other`:当前仅支持单独指定方法中的每一部分所使用的学习率,名称需要与方法中所使用的变量名相同。 + + ```yaml + optimizer: + name: SGD + kwargs: + lr: 0.01 + weight_decay: 2e-4 + momentum: 0.9 + ``` + ++ `lr_scheduler`:训练时使用的学习率调整策略,当前仅支持`Pytorch`提供的所有学习率调整策略。 + + `name`:学习率调整策略名称。 + + `kwargs`:其他`Pytorch`学习率调整策略所需要的参数。 + + ```yaml + lr_scheduler: + name: MultiStepLR + kwargs: + gamma: 0.1 + milestones: [25, 50] + ``` + +### 硬件设置 + ++ `device_ids`:训练可以用到的`gpu`的编号,与`nvidia-smi`命令显示的编号相同。 + ++ `n_gpu`:训练使用并行训练的`gpu`个数,如果仅有`1`个GPU的话,则不适用并行训练。 + ++ `seed`:训练时`numpy`,`torch`,`cuda`使用的种子点。 + ++ `deterministic`:是否开启`torch.backend.cudnn.benchmark`以及`torch.backend.cudnn.deterministic`以及是否使训练随机种子确定。 + + ```yaml + device_ids: 0,1,2,3,4,5,6,7 + n_gpu: 4 + seed: 1993 + deterministic: False + ``` diff --git a/reproduce/api/README.md b/reproduce/api/README.md new file mode 100644 index 0000000000000000000000000000000000000000..57c1ab36d007e09c8d423a6969588014dd02f2a0 --- /dev/null +++ b/reproduce/api/README.md @@ -0,0 +1,31 @@ +# API : Adaptive Plasticity Improvement for Continual Learning [(CVPR'2023)](https://openaccess.thecvf.com/content/CVPR2023/papers/Liang_Adaptive_Plasticity_Improvement_for_Continual_Learning_CVPR_2023_paper.pdf) + +## Abstract + +Many works have tried to solve the catastrophic forgetting (CF) problem in continual learning (lifelong learning). However, pursuing non-forgetting on old tasks may damage the model’s plasticity for new tasks. Although some methods have been proposed to achieve stability-plasticity trade-off, no methods have considered evaluating a model’s plasticity and improving plasticity adaptively for a new task. In this work, we propose a new method, called adaptive plasticity improvement (API), for continual learning. Besides the ability to overcome CF on old tasks, API also tries to evaluate the model’s plasticity and then adaptively improve the model’s plasticity for learning a new task if necessary. Experiments on several real datasets show that API can outperform other state-of-the-art baselines in terms of both accuracy and memory usage. + +## Citation + +```bibtex +@inproceedings{ + liang2023adaptive, + title={Adaptive Plasticity Improvement for Continual Learning}, + author={Liang, Yan-Shuo and Li, Wu-Jun}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={7816--7825}, + year={2023} +} +``` + +## How to Reproduce GPM + +- **Step 1 : Run one of these commands** + ```python + python run_trainter.py --config_name api_til-alexnet-cifar100-b5-5-20 + ``` + +## Results + +| Dataset | Scenario | Num of Tasks | Epochs | Reproduced Accuracy | Reported Accuracy | +| :------: | :------: |:-----------: | :----: | :-----------------: | :---------------: | +| CIFAR100 | TIL | 20 | 200 | 81.24 | 81.40 | diff --git a/reproduce/bic/README.md b/reproduce/bic/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c20375cdb718dfc2e0fc1a9d93692d05488244b0 --- /dev/null +++ b/reproduce/bic/README.md @@ -0,0 +1,52 @@ +# Large Scale Incremental Learning [(CVPR'2019)](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wu_Large_Scale_Incremental_Learning_CVPR_2019_paper.pdf) + +## Abstract + + + +Modern machine learning suffers from catastrophic forgetting when learning new classes incrementally. The performance dramatically degrades due to the missing data of old classes. Incremental learning methods have been proposed to retain the knowledge acquired from the old classes, by using knowledge distilling and keeping a few exemplars from the old classes. However, these methods struggle to scale up to a large number of classes. We believe this is because of the combination of two factors: (a) the data imbalance between the old and new classes, and (b) the increasing number of visually similar classes. Distinguishing between an increasing number of visually similar classes is particularly challenging, when the training data is unbalanced. We propose a simple and effective method to address this data imbalance issue. We found that the last fully connected layer has a strong bias towards the new classes, and this bias can be corrected by a linear model. With two bias parameters, our method performs remarkably well on two large datasets: ImageNet (1000 classes) and MS-Celeb1M (10000 classes), outperforming the state-of-the-art algorithms by 11.1% and 13.2% respectively. + + + +![bic](../../resources/imgs/bic.png) + +## Citation + + + +```bibtex +@inproceedings{wu2019large, + title={Large scale incremental learning}, + author={Wu, Yue and Chen, Yinpeng and Wang, Lijuan and Ye, Yuancheng and Liu, Zicheng and Guo, Yandong and Fu, Yun}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)}, + pages={374--382}, + year={2019} +} +``` + +## How to Reproduce bic + +- **Step1: Set the path in `run_trainer.py` with `./config/bic.yaml`** + ```python + config = Config("./config/bic.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + +**Note**: + +- All the result are trained with single gpu. +- Modifications made to `trainer.py` include the implementation of a two-stage training process for `bic`. +- Conditional checks have been introduced within the code to ascertain whether the second stafe of training should be initiated, ensuring that these changes do not impede the functionality of other algorithms. +- `bic` requires splitting the `buffer` and `new data`, so some conditional statements have been added to the data processing section of `trainer.py` as well. + +## Results on cifar100 dataset with 180 episodes + +| Method | 20 | 40 | 60 | 80 | 100 | +| -------------- | ---- | ----- | ----- | ----- | ----- | +| original bic | 0.84 | 0.747 | 0.679 | 0.613 | 0.567 | +| Ours (stage 1) | 0.89 | 0.705 | 0.650 | 0.578 | 0.514 | +| Ours (stage 2) | 0.89 | 0.725 | 0.687 | 0.628 | 0.578 | + diff --git a/reproduce/cl_lora/README.md b/reproduce/cl_lora/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6282803dc86eeec3685ab1eb060b10c8ee371856 --- /dev/null +++ b/reproduce/cl_lora/README.md @@ -0,0 +1,34 @@ +# CL-LoRA: Continual Low-Rank Adaptation for Rehearsal-Free Class-Incremental Learning [(CVPR'2025)](https://arxiv.org/abs/2505.24816) + +## Abstract +Class-Incremental Learning (CIL) aims to learn new classes sequentially while retaining the knowledge of previously learned classes. Recently, pre-trained models (PTMs) combined with parameter-efficient fine-tuning (PEFT) have shown remarkable performance in rehearsal-free CIL without requiring exemplars from previous tasks. However, existing adapter-based methods, which incorporate lightweight learnable modules into PTMs for CIL, create new adapters for each new task, leading to both parameter redundancy and failure to leverage shared knowledge across tasks. In this work, we propose ContinuaL Low-Rank Adaptation (CL-LoRA), which introduces a novel dual-adapter architecture combining task-shared adapters to learn cross-task knowledge and task-specific adapters to capture unique features of each new task. Specifically, the shared adapters utilize random orthogonal matrices and leverage knowledge distillation with gradient reassignment to preserve essential shared knowledge. In addition, we introduce learnable block-wise weights for task-specific adapters, which mitigate inter-task interference while maintaining the model's plasticity. We demonstrate CL-LoRA consistently achieves promising performance under multiple benchmarks with reduced training and inference computation, establishing a more efficient and scalable paradigm for continual learning with pre-trained models. + +![cl_lora](../../resources/imgs/cl_lora.png) + +## Citation + +```bibtex +@article{He_2025_CVPR, + author = {He, Jiangpeng and Duan, Zhihao and Zhu, Fengqing}, + title = {CL-LoRA: Continual Low-Rank Adaptation for Rehearsal-Free Class-Incremental Learning}, + journal = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)}, + month = {June}, + year = {2025}, + pages = {30534-30544} +} +``` +## How to Reproduce + +- **Step 1 : Run any of these commands** + ```python + python run_trainer.py --config cl_lora-cifar100-b5-5-20 + python run_trainer.py --config cl_lora-imagenetr-b5-5-40 + ``` +## Results + +* Settings : B{init}-{inc}-{t}, init : init_cls_num, inc : inc_cls_num, t : total_tasks + +| Method | Dataset | Settings | Reproduced Last Acc | Reported Last Acc | Reproduced Avg Acc | Reported Avg Acc | +| :--------: | :-------: | :-------: | :-----------------: | :---------------: | :----------------: | :--------------: | +| CL_LoRA | Cifar100 | B5-5-20 | 84.93 | 85.32 | 90.31 | 91.02 | +| CL_LoRA | ImageNetR | B5-5-40 | 73.27 | 74.51 | 81.40 | 81.58 | \ No newline at end of file diff --git a/reproduce/codaprompt/README.md b/reproduce/codaprompt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b94b3c4fceed109c60c2dd6065225eb80fcda50d --- /dev/null +++ b/reproduce/codaprompt/README.md @@ -0,0 +1,34 @@ +# CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning [(CVPR' 2023)](https://openaccess.thecvf.com/content/CVPR2023/html/Smith_CODA-Prompt_COntinual_Decomposed_Attention-Based_Prompting_for_Rehearsal-Free_Continual_Learning_CVPR_2023_paper.html) + +## Abstract +Computer vision models suffer from a phenomenon known as catastrophic forgetting when learning novel concepts from continuously shifting training data. Typical solutions for this continual learning problem require extensive rehearsal of previously seen data, which increases memory costs and may violate data privacy. Recently, the emergence of large-scale pre-trained vision transformer models has enabled prompting approaches as an alternative to data-rehearsal. These approaches rely on a key-query mechanism to generate prompts and have been found to be highly resistant to catastrophic forgetting in the well-established rehearsal-free continual learning setting. However, the key mechanism of these methods is not trained end-to-end with the task sequence. Our experiments show that this leads to a reduction in their plasticity, hence sacrificing new task accuracy, and inability to benefit from expanded parameter capacity. We instead propose to learn a set of prompt components which are assembled with input-conditioned weights to produce input-conditioned prompts, resulting in a novel attention-based end-to-end key-query scheme. Our experiments show that we outperform the current SOTA method DualPrompt on established benchmarks by as much as 4.5% in average final accuracy. We also outperform the state of art by as much as 4.4% accuracy on a continual learning benchmark which contains both class-incremental and domain-incremental task shifts, corresponding to many practical settings. Our code is available at https://github.com/GT-RIPL/CODA-Prompt + +![codaprompt](../../resources/imgs/codaprompt.png) + +## Citation +```bibtex +@inproceedings{smith2023coda, + title={Coda-prompt: Continual decomposed attention-based prompting for rehearsal-free continual learning}, + author={Smith, James Seale and Karlinsky, Leonid and Gutta, Vyshnavi and Cascante-Bonilla, Paola and Kim, Donghyun and Arbelle, Assaf and Panda, Rameswar and Feris, Rogerio and Kira, Zsolt}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={11909--11919}, + year={2023} +} +``` + +## How to Reproduce CodaPrompt + +- **Step1: Set the path in `run_trainer.py` with `./config/codaprompt.yaml`** + ```python + config = Config("./config/codaprompt.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + + +## Results +| Dataset | Backbone |Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------------: |:----------: | :---------: | :-----------------: | :---------------: | +| CIFAR100 | vit_pt_imnet | 10 | 0 | 86.22 | 86.33 | \ No newline at end of file diff --git a/reproduce/dap/README.md b/reproduce/dap/README.md new file mode 100644 index 0000000000000000000000000000000000000000..79f39e0359fdbde63474d6069328056322e360a8 --- /dev/null +++ b/reproduce/dap/README.md @@ -0,0 +1,28 @@ +# Dynamically Anchored Prompting for Task-Imbalanced Continual Learning [(IJCAI' 2024)](https://arxiv.org/abs/2404.14721) + +## Abstract +Existing continual learning literature relies heavily on a strong assumption that tasks arrive with a balanced data stream, which is often unrealistic in real-world applications. In this work, we explore task-imbalanced continual learning (TICL) scenarios where the distribution of task data is non-uniform across the whole learning process. We find that imbalanced tasks significantly challenge the capability of models to control the trade-off between stability and plasticity from the perspective of recent prompt-based continual learning methods. On top of the above finding, we propose Dynamically Anchored Prompting (DAP), a prompt-based method that only maintains a single general prompt to adapt to the shifts within a task stream dynamically. This general prompt is regularized in the prompt space with two specifically designed prompt anchors, called boosting anchor and stabilizing anchor, to balance stability and plasticity in TICL. Remarkably, DAP achieves this balance by only storing a prompt across the data stream, therefore offering a substantial advantage in rehearsal-free CL. Extensive experiments demonstrate that the proposed DAP results in 4.5% to 15% absolute improvements over state-of-the-art methods on benchmarks under task-imbalanced settings. Our code is available at https://github.com/chenxing6666/DAP + +![DAP](../../resources/imgs/dap.png) + +## Citation +```bibtex +@inproceedings{10.24963/ijcai.2024/456, + author = {Hong, Chenxing and Jin, Yan and Kang, Zhiqi and Chen, Yizhou and Li, Mengke and Lu, Yang and Wang, Hanzi}, + title = {Dynamically anchored prompting for task-imbalanced continual learning}, + booktitle = {Proceedings of the Thirty-Third International Joint Conference on Artificial Intelligence}, + year = {2025}, +} +``` + +## How to Reproduce DAP + +- **Step1: Run command** + ```python + python run_trainer.py --config dap.yaml + ``` + +## Results +| Dataset | Backbone |Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------------: |:----------: | :---------: | :-----------------: | :---------------: | +| ltCIFAR100 | vit_pt_imnet_dap | 10 | 0 | 56.58 | 56.30 | \ No newline at end of file diff --git a/reproduce/der/README.md b/reproduce/der/README.md new file mode 100644 index 0000000000000000000000000000000000000000..76eb04f9a6cd2c9ebce0b9978fdda51607bb7016 --- /dev/null +++ b/reproduce/der/README.md @@ -0,0 +1,37 @@ +# DER: Dynamically Expandable Representation for Class Incremental Learning [(CVPR'2021)](https://openaccess.thecvf.com/content/CVPR2021/html/Yan_DER_Dynamically_Expandable_Representation_for_Class_Incremental_Learning_CVPR_2021_paper.html) +## Abstract + +We address the problem of class incremental learning, which is a core step towards achieving adaptive vision intelligence. In particular, we consider the task setting of incremental learning with limited memory and aim to achieve a better stability-plasticity trade-off. To this end, we propose a novel two-stage learning approach that utilizes a dynamically expandable representation for more effective incremental concept modeling. Specifically, at each incremental step, we freeze the previously learned representation and augment it with additional feature dimensions from a new learnable feature extractor. Moreover, we dynamically expand the representation according to the complexity of novel concepts by introducing a channel-level mask-based pruning strategy. This enables us to integrate new visual concepts with retaining learned knowledge. Furthermore, we introduce an auxiliary loss to encourage the model to learn diverse and discriminate features for novel concepts. We conduct extensive experiments on the three class incremental learning benchmarks and our method consistently outperforms other methods with a large margin. + +![DER](../../resources/imgs/der.gif) + +## Citation + +```bibtex +@inproceedings{Yan2021DER, + author = {Shipeng Yan and Jiangwei Xie and Xuming He}, + title = {{DER:} Dynamically Expandable Representation for Class Incremental Learning}, + booktitle = {Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)}, + pages = {3014--3023}, + year = {2021}, +} +``` + +## How to Reproduce DER + +- **Step1: Set the path in `run_trainer.py` with `./config/der.yaml`** + ```python + config = Config("./config/lucir.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + +## Results and models + +| Method | Backbone | Pretrained | Dataset | Epochs | Split | Precision | +| :----: | :------: | :--------: | :-------: | :----: | :---------: | :-------: | +| DER | Resnet18 | False | CIFAR-100 | 170 | Base0 Inc10 | 64.90% | + + diff --git a/reproduce/dmnsp/README.md b/reproduce/dmnsp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9eb1c7401ecd15ef5c88ec13bb3e266055a87ec9 --- /dev/null +++ b/reproduce/dmnsp/README.md @@ -0,0 +1,24 @@ +# Title : TODO + +## Abstract : TODO +TODO: + +## Citation : TODO +```bibtex +``` + +## How to Reproduce DMNSP + +- **Step 1 : Configure `./config/dmnsp.yaml` + +- **Step 2 : Run command** + ```python + python run_trainter.py --config_name dmnsp + ``` + +## Results + +| Dataset | Backbone | Num of tasks | Reproduced Overall Accuracy | Reported Overall Accuracy | +| :------: | :------: | :----------: | :-------------------------: | :-----------------------: | +| CIFAR100 | CLIP | 10 | 86.92 | 87.59 | +| CIFAR100 | CLIP | 5 | 87.67 | 85.29 | diff --git a/reproduce/dualprompt/README.md b/reproduce/dualprompt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5ce45367f4a907c8fad66b54443f905d643eb5c4 --- /dev/null +++ b/reproduce/dualprompt/README.md @@ -0,0 +1,35 @@ +# DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning [(ECCV' 2022)](https://arxiv.org/abs/2204.04799) + +## Abstract +Continual learning aims to enable a single model to learn a sequence of tasks without catastrophic forgetting. Top-performing methods usually require a rehearsal buffer to store past pristine examples for experience replay, which, however, limits their practical value due to privacy and memory constraints. In this work, we present a simple yet effective framework, DualPrompt, which learns a tiny set of parameters, called prompts, to properly instruct a pre-trained model to learn tasks arriving sequentially without buffering past examples. DualPrompt presents a novel approach to attach complementary prompts to the pre-trained backbone, and then formulates the objective as learning task-invariant and task-specific “instructions”. With extensive experimental validation, DualPrompt consistently sets state-of-the-art performance under the challenging class-incremental setting. In particular, DualPrompt outperforms recent advanced continual learning methods with relatively large buffer sizes. We also introduce a more challenging benchmark, Split ImageNet-R, to help generalize rehearsal-free continual learning research. Source code is available at https://github.com/google-research/l2p + +![dualprompt](../../resources/imgs/dualprompt.png) + +## Citation +```bibtex +@inproceedings{wang2022dualprompt, + title={Dualprompt: Complementary prompting for rehearsal-free continual learning}, + author={Wang, Zifeng and Zhang, Zizhao and Ebrahimi, Sayna and Sun, Ruoxi and Zhang, Han and Lee, Chen-Yu and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and others}, + booktitle={European Conference on Computer Vision}, + pages={631--648}, + year={2022}, + organization={Springer} +} +``` + +## How to Reproduce DualPrompt + +- **Step1: Set the path in `run_trainer.py` with `./config/dualprompt.yaml`** + ```python + config = Config("./config/dualprompt.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + + +## Results +| Dataset | Backbone |Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------------: |:----------: | :---------: | :-----------------: | :---------------: | +| CIFAR100 | vit_pt_imnet | 10 | 0 | 83.21 | 83.69 | \ No newline at end of file diff --git a/reproduce/erace,eraml/README.md b/reproduce/erace,eraml/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cf3b5fd433315f73e8476ebd47905cf3e983d3d4 --- /dev/null +++ b/reproduce/erace,eraml/README.md @@ -0,0 +1,60 @@ +# ERACE, ERAML: New Insights on reducing abrupt representation change in online continual learning [(ICLR'2022)](https://arxiv.org/abs/2104.05025) + +## Abstract + +In the online continual learning paradigm, agents must learn from a changing distribution while respecting memory and compute constraints. Experience Replay (ER), where a small subset of past data is stored and replayed alongside new data, has emerged as a simple and effective learning strategy. In this work, we focus on the change in representations of observed data that arises when previously unobserved classes appear in the incoming data stream, and new classes must be distinguished from previous ones. We shed new light on this question by showing that applying ER causes the newly added classes' representations to overlap significantly with the previous classes, leading to highly disruptive parameter updates. + +Based on this empirical analysis, we propose a new method which mitigates this issue by shielding the learned representations from drastic adaptation to accommodate new classes. We show that using an asymmetric update rule pushes new classes to adapt to the older ones (rather than the reverse), which is more effective especially at task boundaries, where much of the forgetting typically occurs. + +## Citation + +```bibtex +@misc{caccia2022new, + title={New Insights on Reducing Abrupt Representation Change in Online Continual Learning}, + author={Lucas Caccia and Rahaf Aljundi and Nader Asadi and Tinne Tuytelaars and Joelle Pineau and Eugene Belilovsky}, + year={2022}, + eprint={2104.05025}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + +## How to Reproduce ERACE, ERAML + +- **Step 1 : Configure `./config/erace.yaml` and `./config/eraml.yaml`** + +- **Step 2 : Run command** + ```python + python run_trainter.py --config_name erace + python run_trainter.py --config_name eraml + ``` + +## Results + +### ERACE + +| Dataset | Num of Tasks | Buffer Size | Epochs | Reproduced Accuracy | Reported Accuray | +| :------: | :----------: | :---------: | :----: | :-----------------: | :--------------: | +| CIFAR10 | 5 | 20*10 | 1 | 31.4 | 42.82 | +| CIFAR10 | 5 | 20*10 | 5 | 42.2 | 49.40 | +| CIFAR10 | 5 | 20*10 | 15 | 46.6 | 44.92 | + +| Dataset | Num of Tasks | Buffer Size | Epochs | Reproduced Accuracy | Reported Accuray | +| :------: | :----------: | :---------: | :----: | :-----------------: | :--------------: | +| CIFAR100 | 20 | 20*100 | 1 | 12.10 | 17.46 | +| CIFAR100 | 20 | 20*100 | 5 | 21.20 | 18.26 | +| CIFAR100 | 20 | 20*100 | 15 | 32.20 | 15.78 | + +### ERAML + +| Dataset | Num of Tasks | Buffer Size | Epochs | Reproduced Accuracy | Reported Accuray | +| :------: | :----------: | :---------: | :----: | :-----------------: | :--------------: | +| CIFAR10 | 5 | 20*10 | 1 | 29.6 | 37.48 | +| CIFAR10 | 5 | 20*10 | 5 | 41.2 | 39.92 | +| CIFAR10 | 5 | 20*10 | 15 | 44.2 | 35.58 | + +| Dataset | Num of Tasks | Buffer Size | Epochs | Reproduced Accuracy | Reported Accuray | +| :------: | :----------: | :---------: | :----: | :-----------------: | :--------------: | +| CIFAR100 | 20 | 20*100 | 1 | 7.00 | 10.26 | +| CIFAR100 | 20 | 20*100 | 5 | 19.20 | 14.52 | +| CIFAR100 | 20 | 20*100 | 15 | 25.30 | 12.22 | \ No newline at end of file diff --git a/reproduce/ewc/README.md b/reproduce/ewc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e392abcdc31474d4e3790c7724f997c4a2c0f59d --- /dev/null +++ b/reproduce/ewc/README.md @@ -0,0 +1,37 @@ +# Overcoming catastrophic forgetting in neural networks [(PNAS' 2017)](https://arxiv.org/abs/1612.00796) + +## Abstract +The ability to learn tasks in a sequential fashion is crucial to the development of artificial intelligence. Neural networks are not, in general, capable of this and it has been widely thought that catastrophic forgetting is an inevitable feature of connectionist models. We show that it is possible to overcome this limitation and train networks that can maintain expertise on tasks which they have not experienced for a long time. Our approach remembers old tasks by selectively slowing down learning on the weights important for those tasks. We demonstrate our approach is scalable and effective by solving a set of classification tasks based on the MNIST hand written digit dataset and by learning several Atari 2600 games sequentially. + +![EWC](../../resources/imgs/EWC.png) + + +## Citation +```bibtex +@article{kirkpatrick2017overcoming, + title={Overcoming catastrophic forgetting in neural networks}, + author={Kirkpatrick, James and Pascanu, Razvan and Rabinowitz, Neil and Veness, Joel and Desjardins, Guillaume and Rusu, Andrei A and Milan, Kieran and Quan, John and Ramalho, Tiago and Grabska-Barwinska, Agnieszka and others}, + journal={Proceedings of the national academy of sciences}, + pages={3521--3526}, + year={2017}, +} +``` + +## How to Reproduce EWC + +- **Step1: Set the path in `run_trainer.py` with `./config/ewc.yaml`** + ```python + config = Config("./config/ewc.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + + +## Results + + +|backbone | Dataset | buffer_size | batch_size | init_cls | inc_cls | acc| +| --- | --- | --- | --- | --- | --- | --- | +| cifar_resnet32| CIFAR-100 | 2000 | 128 | 50 | 25 | 49.75 | \ No newline at end of file diff --git a/reproduce/gpm/README.md b/reproduce/gpm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5d331707da81b4d6c70e7521c20b81c99229b6be --- /dev/null +++ b/reproduce/gpm/README.md @@ -0,0 +1,40 @@ +# GPM : Gradient Projection Memory for Continual Learning [(ICLR'2021)](https://openreview.net/forum?id=3AOj0RCNC2) + +## Abstract + +The ability to learn continually without forgetting the past tasks is a desired attribute for artificial learning systems. Existing approaches to enable such learning in artificial neural networks usually rely on network growth, importance based weight update or replay of old data from the memory. In contrast, we propose a novel approach where a neural network learns new tasks by taking gradient steps in the orthogonal direction to the gradient subspaces deemed important for the past tasks. We find the bases of these subspaces by analyzing network representations (activations) after learning each task with Singular Value Decomposition (SVD) in a single shot manner and store them in the memory as Gradient Projection Memory (GPM). With qualitative and quantitative analyses, we show that such orthogonal gradient descent induces minimum to no interference with the past tasks, thereby mitigates forgetting. We evaluate our algorithm on diverse image classification datasets with short and long sequences of tasks and report better or on-par performance compared to the state-of-the-art approaches. + +## Citation + +```bibtex +@inproceedings{ + saha2021gradient, + title={Gradient Projection Memory for Continual Learning}, + author={Gobinda Saha and Isha Garg and Kaushik Roy}, + booktitle={International Conference on Learning Representations}, + year={2021}, + url={https://openreview.net/forum?id=3AOj0RCNC2} +} +``` + +# Additional Setting + +The original GPM implementation is fixed to task-incremental scenario, we make it support on class-incremental scenario as well. + +## How to Reproduce GPM + +- **Step 1 : Run one of these commands** + ```python + python run_trainter.py --config_name gpm_cil-alexnet-cifar100-b5-5-20 + python run_trainter.py --config_name gpm_cil-alexnet-cifar100-b10-10-10 + python run_trainter.py --config_name gpm_cil-alexnet-cifar100-b20-20-5 + python run_trainter.py --config_name gpm_til-alexnet-cifar100-b10-10-10 + python run_trainter.py --config_name gpm_til-alexnet-cifar100-b20-20-5 + ``` + +## Results + +| Dataset | Scenario | Num of Tasks | Epochs | Reproduced Accuracy | Reported Accuracy | +| :------: | :------: |:-----------: | :----: | :-----------------: | :---------------: | +| CIFAR100 | Til | 10 | 200 | 73.94 | 72.48 | +| CIFAR100 | Til | 20 | 200 | 68.84 | - | diff --git a/reproduce/icarl/README.md b/reproduce/icarl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a822fa0a89c71140cd4ee20e2087bd996b5fa5a0 --- /dev/null +++ b/reproduce/icarl/README.md @@ -0,0 +1,49 @@ +# iCaRL: Incremental Classifier and Representation Learning [(CVPR'2017)](https://arxiv.org/abs/1611.07725) + + + +## Abstract + +A major open problem on the road to artificial intelligence is the development of incrementally learning systems that learn about more and more concepts over time from a stream of data. In this work, we introduce a new training strategy, iCaRL, that allows learning in such a class-incremental way: only the training data for a small number of classes has to be present at the same time and new classes can be added progressively. + +iCaRL learns strong classifiers and a data representation simultaneously. This distinguishes it from earlier works that were fundamentally limited to fixed data representations and therefore incompatible with deep learning architectures. We show by experiments on CIFAR-100 and ImageNet ILSVRC 2012 data that iCaRL can learn many classes incrementally over a long period of time where other strategies quickly fail. + + + +## Citation + +```bibtex +@inproceedings{rebuffi2017icarl, + title={icarl: Incremental classifier and representation learning}, + author={Rebuffi, Sylvestre-Alvise and Kolesnikov, Alexander and Sperl, Georg and Lampert, Christoph H}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)}, + pages={2001--2010}, + year={2017} +} +``` + + + +## How to Reproduce iCaRL + +- **Step1: Set the path in `run_trainer.py` with `./config/icarl.yaml`** + ```python + config = Config("./config/icarl.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + + + +## Results on CIFAR100 dataset + +| Dataset | Num of Tasks | Buffer Size | Reproduced Accuracy | +| :------: | :----------: | :---------: | :-----------------: | +| CIFAR100 | 2 | 2000 | 62.4 | +| CIFAR100 | 5 | 2000 | 54.4 | +| CIFAR100 | 10 | 2000 | 46.5 | + + + diff --git a/reproduce/inflora/README.md b/reproduce/inflora/README.md new file mode 100644 index 0000000000000000000000000000000000000000..51419f7e6dde9f5658b84eb851c9fd3a2b824180 --- /dev/null +++ b/reproduce/inflora/README.md @@ -0,0 +1,42 @@ +# InfLoRA: Interference-Free Low-Rank Adaptation for Continual Learning [(CVPR'2024)](https://openaccess.thecvf.com/content/CVPR2024/html/Liang_InfLoRA_Interference-Free_Low-Rank_Adaptation_for_Continual_Learning_CVPR_2024_paper.html) + +## Abstract +Continual learning requires the model to learn multiple tasks sequentially. In continual learning the model should possess the ability to maintain its performance on old tasks (stability) and the ability to adapt to new tasks continuously (plasticity). Recently parameter-efficient fine-tuning (PEFT) which involves freezing a pre-trained model and injecting a small number of learnable parameters to adapt to downstream tasks has gained increasing popularity in continual learning. Although existing continual learning methods based on PEFT have demonstrated superior performance compared to those not based on PEFT most of them do not consider how to eliminate the interference of the new task on the old tasks which inhibits the model from making a good trade-off between stability and plasticity. In this work we propose a new PEFT method called interference-free low-rank adaptation (InfLoRA) for continual learning. InfLoRA injects a small number of parameters to reparameterize the pre-trained weights and shows that fine-tuning these injected parameters is equivalent to fine-tuning the pre-trained weights within a subspace. Furthermore InfLoRA designs this subspace to eliminate the interference of the new task on the old tasks making a good trade-off between stability and plasticity. Experimental results show that InfLoRA outperforms existing state-of-the-art continual learning methods on multiple datasets. + +![InfLoRA](../../resources/imgs/InfLoRA.png) + +## Citation + +```bibtex +@inproceedings{arXiv:2404.00228v3, + title = {InfLoRA: Interference-Free Low-Rank Adaptation for Continual Learning}, + author = {Yan-Shuo Liang and + Wu-Jun Li}, + booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, {CVPR} 2024, Seattle, Washington}, + publisher = {Computer Vision Foundation / {IEEE}}, + year = {2024}, + url = {https://arxiv.org/abs/2404.00228v3} +} +``` + +## Optimizing InfLoRA + +InfLoRA_opt is an optimized version of InfLoRA . It merges the lora module into weight after every task, reducing the computational cost and improving the performance. + +## How to Reproduce InfLoRA + +- **Step1: Set the path in `run_trainer.py` with `./config/InfLoRA.yaml` or `./config/InfLoRA_opt.yaml`** + ```python + config = Config("./config/InfLoRA.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + +## Results + +| Dataset | Backbone | Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------: | :----------: | :---------: | :-----------------: | :---------------------------------------: | +| CIFAR100 | SiNet | 10 | 0 | 87.09 | $86.51 \pm 0.73$(the 10th task) | +| CIFAR100 | SiNet | 10 | 0 | 91.37 | $91.70 \pm 0.32$(the average of 10 tasks) | diff --git a/reproduce/inflora_opt/README.md b/reproduce/inflora_opt/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6f6fa52cd29951223ef6e2569431b8b2464a696b --- /dev/null +++ b/reproduce/inflora_opt/README.md @@ -0,0 +1,39 @@ +# InfLoRA: Interference-Free Low-Rank Adaptation for Continual Learning [(CVPR'2024)](https://openaccess.thecvf.com/content/CVPR2024/html/Liang_InfLoRA_Interference-Free_Low-Rank_Adaptation_for_Continual_Learning_CVPR_2024_paper.html) + +## Abstract +Continual learning requires the model to learn multiple tasks sequentially. In continual learning the model should possess the ability to maintain its performance on old tasks (stability) and the ability to adapt to new tasks continuously (plasticity). Recently parameter-efficient fine-tuning (PEFT) which involves freezing a pre-trained model and injecting a small number of learnable parameters to adapt to downstream tasks has gained increasing popularity in continual learning. Although existing continual learning methods based on PEFT have demonstrated superior performance compared to those not based on PEFT most of them do not consider how to eliminate the interference of the new task on the old tasks which inhibits the model from making a good trade-off between stability and plasticity. In this work we propose a new PEFT method called interference-free low-rank adaptation (InfLoRA) for continual learning. InfLoRA injects a small number of parameters to reparameterize the pre-trained weights and shows that fine-tuning these injected parameters is equivalent to fine-tuning the pre-trained weights within a subspace. Furthermore InfLoRA designs this subspace to eliminate the interference of the new task on the old tasks making a good trade-off between stability and plasticity. Experimental results show that InfLoRA outperforms existing state-of-the-art continual learning methods on multiple datasets. + +![InfLoRA](../../resources/imgs/InfLoRA.png) + +## Citation + +```bibtex +@inproceedings{liang2024inflora, + title={InfLoRA: Interference-Free Low-Rank Adaptation for Continual Learning}, + author={Liang, Yan-Shuo and Li, Wu-Jun}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={23638--23647}, + year={2024} +} +``` + +## Difference with implemented InfLoRA in this repository + +* Merges the lora modules into pretrained weight after every task, reducing the computational and storing cost. ( Stated in original paper ) +* Implement Classifier Alignment for multiple dataset. + +## How to Reproduce InfLoRA_opt + +- **Step 1 : Configure `./config/InfLoRA_opt.yaml` + +- **Step 2 : Run command** + ```python + python run_trainter.py --config_name InfLoRA_opt + ``` + +## Results + +| Dataset | Backbone | Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------: | :----------: | :---------: | :-----------------: | :---------------------------------------: | +| CIFAR100 | SiNet | 10 | 0 | 87.09 | $86.51 \pm 0.73$(the 10th task) | +| CIFAR100 | SiNet | 10 | 0 | 91.37 | $91.70 \pm 0.32$(the average of 10 tasks) | diff --git a/reproduce/l2p/README.md b/reproduce/l2p/README.md new file mode 100644 index 0000000000000000000000000000000000000000..703132236d454aef6b8d458498b8c160594df132 --- /dev/null +++ b/reproduce/l2p/README.md @@ -0,0 +1,34 @@ +# Learning to Prompt for Continual Learning [(CVPR' 2022)](https://arxiv.org/abs/2112.08654) + +## Abstract +The mainstream paradigm behind continual learning has been to adapt the model parameters to non-stationary data distributions, where catastrophic forgetting is the central challenge. Typical methods rely on a rehearsal buffer or known task identity at test time to retrieve learned knowledge and address forgetting, while this work presents a new paradigm for continual learning that aims to train a more succinct memory system without accessing task identity at test time. Our method learns to dynamically prompt (L2P) a pre-trained model to learn tasks sequentially under different task transitions. In our proposed framework, prompts are small learnable parameters, which are maintained in a memory space. The objective is to optimize prompts to instruct the model prediction and explicitly manage task-invariant and task-specific knowledge while maintaining model plasticity. We conduct comprehensive experiments under popular image classification benchmarks with different challenging continual learning settings, where L2P consistently outperforms prior state-ofthe-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer and is directly applicable to challenging taskagnostic continual learning. Source code is available at https://github.com/google-research/l2p + +![l2p](../../resources/imgs/l2p.png) + +## Citation +```bibtex +@inproceedings{wang2022learning, + title={Learning to prompt for continual learning}, + author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={139--149}, + year={2022} +} +``` + +## How to Reproduce L2P + +- **Step1: Set the path in `run_trainer.py` with `./config/l2p.yaml`** + ```python + config = Config("./config/l2p-vit-cifar100-b10-10-10.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + + +## Results +| Dataset | Backbone |Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------------: |:----------: | :---------: | :-----------------: | :---------------: | +| CIFAR100 | vit_pt_imnet | 10 | 0 | 83.56 | 83.83 | \ No newline at end of file diff --git a/reproduce/lora_sub_drs/README.md b/reproduce/lora_sub_drs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c5fa476b19b9f5881b81364e1c5550dc842fb4ad --- /dev/null +++ b/reproduce/lora_sub_drs/README.md @@ -0,0 +1,46 @@ +# LoRA Subtraction for Drift-Resistant Space in Exemplar-Free Continual Learning [(CVPR'2025)](https://openaccess.thecvf.com//content/CVPR2025/papers/Liu_LoRA_Subtraction_for_Drift-Resistant_Space_in_Exemplar-Free_Continual_Learning_CVPR_2025_paper.pdf) + +## Abstract +In continual learning (CL), catastrophic forgetting often arises due to feature drift. This challenge is particularly prominent in the exemplar-free continual learning (EFCL) setting, where samples from previous tasks cannot be retained, making it difficult to preserve prior knowledge. To address this issue, some EFCL methods aim to identify feature spaces that minimize the impact on previous tasks while accommodating new ones. However, they rely on static features or outdated statistics stored from old tasks, which prevents them from capturing the dynamic evolution of the feature space in CL, leading to performance degradation over time. In this paper, we introduce the Drift-Resistant Space (DRS), which effectively handles feature drifts without requiring explicit feature modeling or the storage of previous tasks. A novel parameter-efficient fine-tuning approach called Low-Rank Adaptation Subtraction (LoRA ) is proposed to develop the DRS. This method subtracts the LoRA weights of old tasks from the initial pre-trained weight before processing new task data to establish the DRS for model training. Therefore, LoRA enhances stability, improves efficiency, and simplifies implementation. Furthermore, stabilizing feature drifts allows for better plasticity by learning with a triplet loss. Our method consistently achieves state-of-the-art results, especially for long task sequences, across multiple datasets. + +![LoRA_SUB_DRS](https://github.com/scarlet0703/LoRA-Sub-DRS/raw/master/imgs/pipeline.png) +s +## Citation + +```bibtex +@inproceedings{liu2025lora, + title={LoRA Subtraction for Drift-Resistant Space in Exemplar-Free Continual Learning}, + author={Liu, Xuan and Chang, Xiaobin}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2025} +} +``` + +## How to Reproduce LoRA_SUB_DRS + +- **Run command** + ```python + python run_trainer.py --config lora_sub-cifar100-b10-10-10 + python run_trainer.py --config lora_sub-cifar100-b5-5-20 + python run_trainer.py --config lora_sub-imgnr-b20-20-10 + python run_trainer.py --config lora_sub-imgnr-b10-10-20 + ``` + +## Results + +### Last Accuracy +| Dataset | Backbone | Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------: | :----------: | :---------: | :-----------------: | :---------------: | +| CIFAR100 | SiNet | 10 | 0 | 89.50 | 89.14 | +| CIFAR100 | SiNet | 20 | 0 | 88.29 | 88.69 | +| Imagenet-R | SiNet | 10 | 0 | 75.05 | 74.74 | +| Imagenet-R | SiNet | 20 | 0 | 73.62 | 74.80 | + +### Overall Average Accuracy +| Dataset | Backbone | Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------: | :----------: | :---------: | :-----------------: | :---------------: | +| CIFAR100 | SiNet | 10 | 0 | 92.55 | 93.11 | +| CIFAR100 | SiNet | 20 | 0 | 92.25 | 92.71 | +| Imagenet-R | SiNet | 10 | 0 | 81.16 | 80.35 | +| Imagenet-R | SiNet | 20 | 0 | 80.69 | 81.21 | + diff --git a/reproduce/lucir/README.md b/reproduce/lucir/README.md new file mode 100644 index 0000000000000000000000000000000000000000..45c33e99c520ae5f5c6bd804fbc7640da0dba78b --- /dev/null +++ b/reproduce/lucir/README.md @@ -0,0 +1,37 @@ +# Learning a Unified Classifier Incrementally via Rebalancing [(CVPR'2019)](https://openaccess.thecvf.com/content_CVPR_2019/html/Hou_Learning_a_Unified_Classifier_Incrementally_via_Rebalancing_CVPR_2019_paper.html) + +## Abstract +Conventionally, deep neural networks are trained offline, relying on a large dataset prepared in advance. This paradigm is often challenged in real-world applications, e.g. online services that involve continuous streams of incoming data. Recently, incremental learning receives increasing attention, and is considered as a promising solution to the practical challenges mentioned above. However, it has been observed that incremental learning is subject to a fundamental difficulty – catastrophic forgetting, namely adapting a model to new data often results in severe performance degradation on previous tasks or classes. Our study reveals that the imbalance between previous and new data is a crucial cause to this problem. In this work, we develop a new framework for incrementally learning a unified classifier, i.e. a classifier that treats both old and new classes uniformly. Specifically, we incorporate three components, cosine normalization, less-forget constraint, and inter-class separation, to mitigate the adverse effects of the imbalance. Experiments show that the proposed method can effectively rebalance the training process, thus obtaining superior performance compared to the existing methods. On CIFAR100 and ImageNet, our method can reduce the classification errors by more than 6% and 13% respectively, under the incremental setting of 10 phases. + +![LUCIR](../../resources/imgs/LUCIR.png) + +## Citation +```bibtex +@inproceedings{hou2019learning, + title={Learning a unified classifier incrementally via rebalancing}, + author={Hou, Saihui and Pan, Xinyu and Loy, Chen Change and Wang, Zilei and Lin, Dahua}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)}, + pages={831--839}, + year={2019} +} +``` + +## How to Reproduce LUCIR + +- **Step1: Set the path in `run_trainer.py` with `./config/lucir.yaml`** + ```python + config = Config("./config/lucir.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + + +## Results + + +| Dataset | Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :----------: | :---------: | :-----------------: | :---------------: | +| CIFAR100 | 2 | 2000 | 63.00 | 62.41 | +| CIFAR100 | 5 | 2000 | 47.40 | 48.91 | \ No newline at end of file diff --git a/reproduce/lwf/README.md b/reproduce/lwf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..08bc0761e674b09c680504f79229f7da7a51ae1e --- /dev/null +++ b/reproduce/lwf/README.md @@ -0,0 +1,39 @@ +# Learning without Forgetting [(ECCV'2016)](https://ieeexplore.ieee.org/document/8107520) +## Abstract + +When building a unified vision system or gradually adding new apabilities to a system, the usual assumption is that training data for all tasks is always available. However, as the number of tasks grows, storing and retraining on such data becomes infeasible. A new problem arises where we add new capabilities to a Convolutional Neural Network (CNN), but the training data for its existing capabilities are unavailable. We propose our Learning without Forgetting method, which uses only new task data to train the network while preserving the original capabilities. Our method performs favorably compared to commonly used feature extraction and fine-tuning adaption techniques and performs similarly to multitask learning that uses original task data we assume unavailable. A more surprising observation is that Learning without Forgetting may be able to replace fine-tuning with similar old and new task datasets for improved new task performance. + +![LwF](../../resources/imgs/lwf.gif) + +## Citation + +```bibtex +@inproceedings{Li2018LwF, + title = {Learning Without Forgetting}, + author = {Zhizhong Li and + Derek Hoiem}, + booktitle = {Computer Vision European Conference (ECCV)}, + pages = {614--629}, + year = {2016}, +} +``` + +## How to reproduce LWF + +- **Step1: Set the path in `run_trainer.py` with `./config/lwf.yaml`** + ```python + config = Config("./config/lwf.yaml").get_config_dict() + ``` +- **Step2: Run code** + ```python + python run_trainer.py + ``` + +## Results and models + +| Backbone | Pretrained | Dataset | Epochs | Split | Precision | +| :------: | :--------: | :-------: | :----: | :---------: | :-------: | +| Resnet18 | False | CIFAR-100 | 100 | Base0 Inc10 | 43.00% | +| Resnet18 | False | CIFAR-100 | 100 | Base0 Inc5 | 43.90% | + + diff --git a/reproduce/moe_adapter4cl/README.md b/reproduce/moe_adapter4cl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e96693baf7bf728230a4248e4e772d75e064079c --- /dev/null +++ b/reproduce/moe_adapter4cl/README.md @@ -0,0 +1,40 @@ +# Boosting Continual Learning of Vision-Language Models via Mixture-of-Experts Adapters[(CVPR'2024)](https://arxiv.org/abs/2403.11549) + +## Abstract +Continual learning can empower vision-language models to continuously acquire new knowledge, without the need for access to the entire historical dataset. However, mitigating the performance degradation in large-scale models is non-trivial due to (i) parameter shifts throughout lifelong learning and (ii) significant computational burdens associated with full-model tuning. In this work, we present a parameter-efficient continual learning framework to alleviate long-term forgetting in incremental learning with vision-language models. Our approach involves the dynamic expansion of a pre-trained CLIP model, through the integration of Mixture-of-Experts (MoE) adapters in response to new tasks. To preserve the zero-shot recognition capability of vision-language models, we further introduce a Distribution Discriminative Auto-Selector (DDAS) that automatically routes in-distribution and out-of-distribution inputs to the MoE Adapter and the original CLIP, respectively. Through extensive experiments across various settings, our proposed method consistently outperforms previous state-of-the-art approaches while concurrently reducing parameter training burdens by 60%. + +![InfLoRA](../../resources/imgs/moe_adapter4cl.png) + +## Citation +```bibtex +@inproceedings{yu2024boosting, + title={Boosting continual learning of vision-language models via mixture-of-experts adapters}, + author={Yu, Jiazuo and Zhuge, Yunzhi and Zhang, Lu and Hu, Ping and Wang, Dong and Lu, Huchuan and He, You}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={23219--23230}, + year={2024} +} +``` + +## How to Reproduce MoE_Adapter4CL + +- **Step 1 : Configure `./config/moe_adapter4cl.yaml` + +- **Step 2 : Run command** + ```python + python run_trainter.py --config_name moe_adapter4cl + ``` + +## Results + +| Dataset | Backbone | Settings | Reproduced Overall Accuracy | Reported Overall Accuracy | +| :------: | :------: | :----------: | :-------------------------: | :-----------------------: | +| CIFAR100 | CLIP | B10-10-10 | 85.49 | 85.21 | +| CIFAR100 | CLIP | B5-5-20 | 86.51 | 83.72 | +| CIFAR100 | CLIP | B2-2-50 | xx.xx | 83.60 | + +| Dataset | Backbone | Settings | Reproduced Overall Accuracy | Reported Overall Accuracy | +| :----------: | :------: | :-----------: | :-------------------------: | :-----------------------: | +| TinyImageNet | CLIP | B100-5-21 | 80.10 | 79.96 | +| TinyImageNet | CLIP | B100-10-11 | 80.43 | 80.23 | +| TinyImageNet | CLIP | B100-20-6 | 80.88 | 81.12 | \ No newline at end of file diff --git a/reproduce/ocm/README.md b/reproduce/ocm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d316c0e5bafa2d36a6a54f5378b93dc5d8a53b8a --- /dev/null +++ b/reproduce/ocm/README.md @@ -0,0 +1,56 @@ +# Online Continual Learning through Mutual Information Maximization [(ICML'2021)](https://proceedings.mlr.press/v162/guo22g.html) + +## Abstract + +This paper proposes a new online continual learning technique called OCM based on mutual information maximization. It achieves two objectives that are critical in dealing with catastrophic forgetting (CF). (1) It reduces feature bias caused by cross entropy (CE) as CE learns only discriminative features for each task, but these features may not be discriminative for another task. To learn a new task well, the network parameters learned before have to be modified, which causes CF. The new approach encourages the learning of each task to make use of holistic representations or the full features of the task training data. (2) It encourages preservation of the previously learned knowledge when training a new batch of incrementally arriving data. Empirical evaluation shows that OCM substantially outperforms the online CL baselines. For example, for CIFAR10, OCM improves the accuracy of the best baseline by 13.1% from 64.1% (baseline) to 77.2% (OCM). The code is publicly available at https://github.com/gydpku/OCM. + +![OCM](../../resources/imgs/OCM.png) + +## Citation + +```bibtex +@inproceedings{guo2022online, + title={Online continual learning through mutual information maximization}, + author={Guo, Yiduo and Liu, Bing and Zhao, Dongyan}, + booktitle={International Conference on Machine Learning (ICML)}, + pages={8109--8126}, + year={2022} +} +``` + +## How to Reproduce OCM + +- **Step1: Set the path in `run_trainer.py` with `./config/ocm.yaml`** + ```python + config = Config("./config/ocm.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + + + +## Notes + +`OCM` is an online continual learning method, so it requires the use of a specific `OnlineBuffer`. Therefore, the `buffer` section in your configuration file should be configured as follows: + +```yaml +buffer: + name: OnlineBuffer + kwargs: + buffer_size: 5000 + batch_size: 64 + input_size: [3, 32, 32] +``` + + + +## Results on CIFAR100 dataset + +| Dataset | Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :----------: | :---------: | :-----------------: | :---------------: | +| CIFAR100 | 10 | 1000 | 28.6 | 28.1 | +| CIFAR100 | 10 | 2000 | 35.7 | 35.0 | +| CIFAR100 | 10 | 5000 | 41.0 | 42.4 | + diff --git a/reproduce/praka/README.md b/reproduce/praka/README.md new file mode 100644 index 0000000000000000000000000000000000000000..aa04daf908d074ffb6a627e8c0c475d4011dfa78 --- /dev/null +++ b/reproduce/praka/README.md @@ -0,0 +1,38 @@ +# Prototype Reminiscence and Augmented Asymmetric Knowledge Aggregation for Non-Exemplar Class-Incremental Learning [(ICCV' 2023)](https://openaccess.thecvf.com/content/ICCV2023/papers/Shi_Prototype_Reminiscence_and_Augmented_Asymmetric_Knowledge_Aggregation_for_Non-Exemplar_Class-Incremental_ICCV_2023_paper.pdf) + +## Abstract +Non-exemplar class-incremental learning(NECIL) requires deep models to maintain existing knowledge while continuously learning new classes without saving old class samples. In NECIL methods, prototypical representations are usually stored, which inject information from former classes to resist catastrophic forgetting in subsequent incremental learning. However, since the model continuously learns new knowledge, the stored prototypical representations cannot correctly model the properties of old classes in the existence of knowledge updates. To address this problem, we propose a novel prototype reminiscence mechanism that incorporates the previous class prototypes with arriving new class features to dynamically reshape old class feature distributions thus preserving the decision boundaries of previous tasks. In addition, to improve the model generalization on both newly arriving classes and old classes, we contribute an augmented asymmetric knowledge aggregation approach, which aggregates the overall knowledge of the current task and extracts the valuable knowledge of the past tasks, on top of self-supervised label augmentation. Experimental results on three benchmarks suggest the superior performance of our approach over the SOTA methods. + +![PRAKA](../../resources/imgs/praka.png) + + +## Citation +If you use this code for your research, please consider citing: + +``` +@inproceedings{shi2023prototype, + title={Prototype Reminiscence and Augmented Asymmetric Knowledge Aggregation for Non-Exemplar Class-Incremental Learning}, + author={Shi, Wuxuan and Ye, Mang}, + booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, + pages={1772--1781}, + year={2023} +} +``` + +## How to Reproduce PRAKA + +- **Step1: Set the path in `run_trainer.py` with `./config/praka.yaml`** + ```python + config = Config("./config/praka.yaml").get_config_dict() + ``` +- **Step2: Run command** + + ```python + python run_trainer.py + ``` + +| Dataset | Num of Tasks | Buffer Size | Reproduced Average Accuracy | Reported Average Accuracy | +| :------: | :----------: | :---------: | :-------------------------: | :-----------------------: | +| CIFAR100 | 5 | 0 | 69.54 | 70.02 | +| CIFAR100 | 10 | 0 | 67.12 | 68.86 | + diff --git a/reproduce/ranpac/README.md b/reproduce/ranpac/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ee0975dd597cbf9890e1052da527ea81c12cee95 --- /dev/null +++ b/reproduce/ranpac/README.md @@ -0,0 +1,38 @@ +# RanPAC: Random Projections and Pre-trained Models for Continual Learning [(NeurIPS 2023)](https://arxiv.org/abs/2307.02251) + +## Abstract +Continual learning (CL) aims to incrementally learn different tasks (such as classification) in a non-stationary data stream without forgetting old ones. Most CL works focus on tackling catastrophic forgetting under a learning-from-scratch paradigm. However, with the increasing prominence of foundation models, pre-trained models equipped with informative representations have become available for various downstream requirements. Several CL methods based on pre-trained models have been explored, either utilizing pre-extracted features directly (which makes bridging distribution gaps challenging) or incorporating adaptors (which may be subject to forgetting). In this paper, we propose a concise and effective approach for CL with pre-trained models. Given that forgetting occurs during parameter updating, we contemplate an alternative approach that exploits training-free random projectors and class-prototype accumulation, which thus bypasses the issue. Specifically, we inject a frozen Random Projection layer with nonlinear activation between the pre-trained model's feature representations and output head, which captures interactions between features with expanded dimensionality, providing enhanced linear separability for class-prototype-based CL. We also demonstrate the importance of decorrelating the class-prototypes to reduce the distribution disparity when using pre-trained representations. These techniques prove to be effective and circumvent the problem of forgetting for both class- and domain-incremental continual learning. Compared to previous methods applied to pre-trained ViT-B/16 models, we reduce final error rates by between 20% and 62% on seven class-incremental benchmarks, despite not using any rehearsal memory. We conclude that the full potential of pre-trained models for simple, effective, and fast CL has not hitherto been fully tapped. + +![RanPAC](../../resources/imgs/ranpac.png) + +## Citation + +```bibtex +@misc{mcdonnell2024ranpacrandomprojectionspretrained, + title={RanPAC: Random Projections and Pre-trained Models for Continual Learning}, + author={Mark D. McDonnell and Dong Gong and Amin Parveneh and Ehsan Abbasnejad and Anton van den Hengel}, + year={2024}, + eprint={2307.02251}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2307.02251}, +} +``` + +## How to Reproduce RanPAC + +- **Step1: Set the path in `run_trainer.py` with `./config/ranpac.yaml`** + ```python + config = Config("./config/ranpac.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + +## Results + +| Dataset | Backbone | Num of tasks | Buffer size | Reproduced Accuracy | Reported Accuracy | +| :------: | :------: | :----------: | :---------: | :-----------------: | :-----------------: | +| CIFAR100 | ViT | 10 | 0 | 92.22 | 92.11 | + diff --git a/reproduce/rapf/README.md b/reproduce/rapf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..61250d557ba69dc4e24e2d41e1185bd2d4cd80d6 --- /dev/null +++ b/reproduce/rapf/README.md @@ -0,0 +1,26 @@ +# Class-Incremental Learning with CLIP: Adaptive Representation Adjustment and Parameter Fusion [(ECCV 2024)](https://arxiv.org/abs/2407.14143) + +## Abstract +Class-incremental learning is a challenging problem, where the goal is to train a model that can classify data from an increasing number of classes over time. With the advancement of vision-language pre-trained models such as CLIP, they demonstrate good generalization ability that allows them to excel in class-incremental learning with completely frozen parameters. However, further adaptation to downstream tasks by simply fine-tuning the model leads to severe forgetting. Most existing works with pre-trained models assume that the forgetting of old classes is uniform when the model acquires new knowledge. In this paper, we propose a method named Adaptive Representation Adjustment and Parameter Fusion (RAPF). During training for new data, we measure the influence of new classes on old ones and adjust the representations, using textual features. After training, we employ a decomposed parameter fusion to further mitigate forgetting during adapter module fine-tuning. Experiments on several conventional benchmarks show that our method achieves state-of-the-art results. Our code is available at [this URL](https://github.com/linlany/RAPF). + +![RAPF](../../resources/imgs/rapf.png) + +## How to Reproduce RanPAC + +- **Step1: Set the path in `run_trainer.py` with `./config/trgp.yaml`** + ```python + config = Config("./config/rapf10-10.yaml").get_config_dict() + config = Config("./config/rapf50-10.yaml").get_config_dict() + config = Config("./config/rapf50-5.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + +## Results +| Dataset | Config | Reproduced (avg, last) | Reported (avg, last) | +|------------|--------|------------------------|-----------------------| +| CIFAR-100 | 10-10 | (85.97, 78.85) | (86.87, 79.26) | +| CIFAR-100 | 50-10 | (84.39, 79.58) | (84.73, 79.36) | +| CIFAR-100 | 50-5 | (84.41, 79.77) | (85.03, 79.64) | diff --git a/reproduce/sd_lora/README.md b/reproduce/sd_lora/README.md new file mode 100644 index 0000000000000000000000000000000000000000..63dff117b150e971f2725d3e4e254c27c1136715 --- /dev/null +++ b/reproduce/sd_lora/README.md @@ -0,0 +1,55 @@ +# SD-LoRA: Scalable Decoupled Low-Rank Adaptation for Class Incremental Learning [(ICLR'2025)](https://arxiv.org/abs/2501.13198) + +## Abstract +Continual Learning (CL) with foundation models has recently emerged as a promising paradigm to exploit abundant knowledge acquired during pre-training for tackling sequential tasks. However, existing prompt-based and Low-Rank Adaptation-based (LoRA-based) methods often require expanding a prompt/LoRA pool or retaining samples of previous tasks, which poses significant scalability challenges as the number of tasks grows. To address these limitations, we propose Scalable Decoupled LoRA (SD-LoRA) for class incremental learning, which continually separates the learning of the magnitude and direction of LoRA components without rehearsal. Our empirical and theoretical analysis reveals that SD-LoRA tends to follow a low-loss trajectory and converges to an overlapping low-loss region for all learned tasks, resulting in an excellent stability-plasticity trade-off. Building upon these insights, we introduce two variants of SD-LoRA with further improved parameter efficiency. All parameters of SD-LoRAs can be end-to-end optimized for CL objectives. Meanwhile, they support efficient inference by allowing direct evaluation with the finally trained model, obviating the need for component selection. Extensive experiments across multiple CL benchmarks and foundation models consistently validate the effectiveness of SD-LoRA. + +![sd_lora](../../resources/imgs/sd_lora.png) + +## Citation + +```bibtex +@misc{wu2025sdlorascalabledecoupledlowrank, + title={SD-LoRA: Scalable Decoupled Low-Rank Adaptation for Class Incremental Learning}, + author={Yichen Wu and Hongming Piao and Long-Kai Huang and Renzhen Wang and Wanhua Li and Hanspeter Pfister and Deyu Meng and Kede Ma and Ying Wei}, + year={2025}, + eprint={2501.13198}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2501.13198}, +} +``` +## How to Reproduce + +- **Step 1 : Run any of these commands** + ```python + python run_trainer.py --config sd_lora-vit-cifar100-b10-10-10 + python run_trainer.py --config sd_lora-vit-imagenetr-b10-10-20 + python run_trainer.py --config sd_lora-vit-imagenetr-b20-20-10 + python run_trainer.py --config sd_lora-vit-imagenetr-b40-40-5 + ``` +## Results + +* Settings : B{init}-{inc}-{t}, init : init_cls_num, inc : inc_cls_num, t : total_tasks +* The source code utilizes different hyperparameters—such as the optimizer, learning rate, and number of epochs compared to those specified in the papers. +* We have chosen to adhere to the configurations defined in the source code. + +| Method | Dataset | Settings | Reproduced Last Acc | Reported Last Acc | Reproduced Avg Acc | Reported Avg Acc | +| :--------: | :-------: | :-------: | :-----------------: | :---------------: | :----------------: | :--------------: | +| SD_LoRA | Cifar100 | B10-10-10 | 87.16 | 88.01 | 92.21 | 92.54 | +| SD_LoRA | ImageNetR | B10-10-20 | 76.20 | 75.26 | 81.38 | 80.22 | +| SD_LoRA | ImageNetR | B20-20-10 | 78.35 | 77.34 | 83.16 | 82.04 | +| SD_LoRA | ImageNetR | B40-40-5 | 79.97 | 79.15 | 83.45 | 83.01 | + +| Method | Dataset | Settings | Reproduced Last Acc | Reported Last Acc | Reproduced Avg Acc | Reported Avg Acc | +| :---------: | :-------: | :-------: | :-----------------: | :---------------: | :----------------: | :--------------: | +| SD_LoRA-RR | Cifar100 | B10-10-10 | 87.14 | 87.26 | 92.27 | 92.05 | +| SD_LoRA-RR | ImageNetR | B10-10-20 | 75.92 | 74.05 | 81.43 | 80.65 | +| SD_LoRA-RR | ImageNetR | B20-20-10 | 78.73 | 77.18 | 83.18 | 81.74 | +| SD_LoRA-RR | ImageNetR | B40-40-5 | 80.08 | 79.01 | 83.48 | 82.50 | + +| Method | Dataset | Settings | Reproduced Last Acc | Reported Last Acc | Reproduced Avg Acc | Reported Avg Acc | +| :---------: | :-------: | :-------: | :-----------------: | :---------------: | :----------------: | :--------------: | +| SD_LoRA-KD | Cifar100 | B10-10-10 | 86.75 | 87.09 | 91.92 | 92.01 | +| SD_LoRA-KD | ImageNetR | B10-10-20 | 76.00 | 74.12 | 81.48 | 80.11 | +| SD_LoRA-KD | ImageNetR | B20-20-10 | 78.90 | 77.03 | 83.08 | 81.52 | +| SD_LoRA-KD | ImageNetR | B40-40-5 | 79.87 | 78.85 | 83.20 | 82.47 | diff --git a/reproduce/trgp/README.md b/reproduce/trgp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4f6f4bbb9b0e816821fc0f423a6522626ff867b2 --- /dev/null +++ b/reproduce/trgp/README.md @@ -0,0 +1,33 @@ +# TRGP: Trust Region Gradient Projection for Continual Learning [(ICLR 2022)](https://arxiv.org/abs/2202.02931) + +## Abstract +Catastrophic forgetting is one of the major challenges in continual learning. To address this issue, some existing methods put restrictive constraints on the optimization space of the new task for minimizing the interference to old tasks. However, this may lead to unsatisfactory performance for the new task, especially when the new task is strongly correlated with old tasks. To tackle this challenge, we propose Trust Region Gradient Projection (TRGP) for continual learning to facilitate the forward knowledge transfer based on an efficient characterization of task correlation. Particularly, we introduce a notion of `trust region' to select the most related old tasks for the new task in a layer-wise and single-shot manner, using the norm of gradient projection onto the subspace spanned by task inputs. Then, a scaled weight projection is proposed to cleverly reuse the frozen weights of the selected old tasks in the trust region through a layer-wise scaling matrix. By jointly optimizing the scaling matrices and the model, where the model is updated along the directions orthogonal to the subspaces of old tasks, TRGP can effectively prompt knowledge transfer without forgetting. Extensive experiments show that our approach achieves significant improvement over related state-of-the-art methods. + +## Citation + +```bibtex +@article{lin2022trgp, + title={TRGP: Trust Region Gradient Projection for Continual Learning}, + author={Lin, Sen and Yang, Li and Fan, Deliang and Zhang, Junshan}, + journal={arXiv preprint arXiv:2202.02931}, + year={2022} +} +``` + +## How to Reproduce + +- **Step1: Set the path in `run_trainer.py` with `./config/trgp.yaml`** + ```python + config = Config("./config/trgp.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + +## Results + +| Dataset | Backbone | Num of tasks | Buffer size | Epochs | Reproduced Accuracy | Reported Accuracy | +| :------: | :------: | :----------: | :---------: | :----: | :-----------------: | :-----------------: | +| CIFAR100 | AlexNet | 10 | 0 | 200 | 77.77 | 74.49 | + diff --git a/reproduce/wa/README.md b/reproduce/wa/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc996e8bdcf1fe32ce3ea2921c2e499320af6190 --- /dev/null +++ b/reproduce/wa/README.md @@ -0,0 +1,40 @@ +# Maintaining Discrimination and Fairness in Class Incremental Learning [(CVPR'2020)](https://arxiv.org/abs/1911.07053) + +## Abstract + +Deep neural networks (DNNs) have been applied in class incremental learning, which aims to solve common real-world problems of learning new classes continually. One drawback of standard DNNs is that they are prone to catastrophic forgetting. Knowledge distillation (KD) is a commonly used technique to alleviate this problem. In this paper, we demonstrate it can indeed help the model to output more discriminative results within old classes. However, it cannot alleviate the problem that the model tends to classify objects into new classes, causing the positive effect of KD to be hidden and limited. We observed that an important factor causing catastrophic forgetting is that the weights in the last fully connected (FC) layer are highly biased in class incremental learning. In this paper, we propose a simple and effective solution motivated by the aforementioned observations to address catastrophic forgetting. Firstly, we utilize KD to maintain the discrimination within old classes. Then, to further maintain the fairness between old classes and new classes, we propose Weight Aligning (WA) that corrects the biased weights in the FC layer after normal training process. Unlike previous work, WA does not require any extra parameters or a validation set in advance, as it utilizes the information provided by the biased weights themselves. The proposed method is evaluated on ImageNet-1000, ImageNet-100, and CIFAR-100 under various settings. Experimental results show that the proposed method can effectively alleviate catastrophic forgetting and significantly outperform state-of-the-art methods. + +![WA](../../resources/imgs/wa.png) + + +## Citation + +```bibtex +@inproceedings{zhao2020maintaining, + title={Maintaining discrimination and fairness in class incremental learning}, + author={Zhao, Bowen and Xiao, Xi and Gan, Guojun and Zhang, Bin and Xia, Shu-Tao}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR)}, + pages={13208--13217}, + year={2020} +} +``` + +## How to Reproduce WA + +- **Step1: Set the path in `run_trainer.py` with `./config/wa.yaml`** + ```python + config = Config("./config/wa.yaml").get_config_dict() + ``` +- **Step2: Run command** + ```python + python run_trainer.py + ``` + + +## Results on CIFAR-100 dataset + +| Arch | Input Size | Batch Size | Buffer Size | Epochs | Task Number | Average ACC | +| :------: | :--------: | :--------: | :---------: | :----: | :---------: | :---------: | +| resnet32 | 32x32 | 128 | 2000 | 250 | 5 | 47.2% | + +Note: The paper of WA only used a buffer size of 2000 for the experiment. diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b23c78767fa45e407c428ce4e8da61c94519b0f6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +diffdist==0.1 +numpy==1.21.6 +Pillow==10.2.0 +scikit_learn==1.0.2 +torch==2.0.1 +torchvision==0.15.2 +pandas==1.3.5 +tqdm==4.64.1 +timm==0.6.7 +PyYAML==6.0.2 +regex==2024.11.6 +ftfy==6.3.1 +continuum==1.2.7 diff --git a/requirements.yaml b/requirements.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f93065e489832a1b858c3a615ad6a7fcdd27cd26 --- /dev/null +++ b/requirements.yaml @@ -0,0 +1,19 @@ +name: libcontinual-env +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - python=3.8 + # Conda 包及其精确版本 + - numpy=1.21.5 + - pandas=1.1.5 + - pillow=9.2.0 + - pyyaml=6.0.1 + - scikit-learn=1.0.2 + - pytorch=1.12.1 + - torchvision=0.13.1 + - tqdm=4.64.1 + + - pip: + - diffdist==0.1 diff --git a/resources/imgs/EWC.png b/resources/imgs/EWC.png new file mode 100644 index 0000000000000000000000000000000000000000..fc56dff2634e8cf80e621c78f53c746597e9a6b5 Binary files /dev/null and b/resources/imgs/EWC.png differ diff --git a/resources/imgs/InfLoRA.png b/resources/imgs/InfLoRA.png new file mode 100644 index 0000000000000000000000000000000000000000..d73a095165d156377674c68628df99f915102097 --- /dev/null +++ b/resources/imgs/InfLoRA.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e4c600f0d4b3f0b843d5a9ca956d2787183f4e1b45103181d38d99950871f844 +size 161334 diff --git a/resources/imgs/LUCIR.png b/resources/imgs/LUCIR.png new file mode 100644 index 0000000000000000000000000000000000000000..13c2d87941540f3492b100767883aea30db85f9d --- /dev/null +++ b/resources/imgs/LUCIR.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af440b730fee22910cada12f9883495c352f6aa11e243f22ba0ba1548ec8bd84 +size 120259 diff --git a/resources/imgs/OCM.png b/resources/imgs/OCM.png new file mode 100644 index 0000000000000000000000000000000000000000..7e9331e36e895881167cd5504af0db42c59dc24b Binary files /dev/null and b/resources/imgs/OCM.png differ diff --git a/resources/imgs/bic.png b/resources/imgs/bic.png new file mode 100644 index 0000000000000000000000000000000000000000..3fc3373cc631d0d30004b2a2846ae297383ac7f3 Binary files /dev/null and b/resources/imgs/bic.png differ diff --git a/resources/imgs/cl_lora.png b/resources/imgs/cl_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..a62f78f57430aac0aa1190fc76becfbd48a20987 --- /dev/null +++ b/resources/imgs/cl_lora.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c179ddaf29d04561e42dc5823c3f516a88f17a174b50f22ab0bca36ced6696fa +size 118988 diff --git a/resources/imgs/codaprompt.png b/resources/imgs/codaprompt.png new file mode 100644 index 0000000000000000000000000000000000000000..95288d02649537e0c05d993e12a9a606b98d1494 --- /dev/null +++ b/resources/imgs/codaprompt.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da48b52fe5c88a2f6f9f509c891cbfb4f4fa29e818e3e3540627cc8a1bf2ab17 +size 122173 diff --git a/resources/imgs/dap.png b/resources/imgs/dap.png new file mode 100644 index 0000000000000000000000000000000000000000..ed856b6d9b69af7bc1d8ad6d8f37fa056ad4fbd8 --- /dev/null +++ b/resources/imgs/dap.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42dee6a7fe0452e351ee1f99d3ad7e0c25829cefb5d96eca7ecabf5911f03349 +size 222135 diff --git a/resources/imgs/der.gif b/resources/imgs/der.gif new file mode 100644 index 0000000000000000000000000000000000000000..df0b3b536fba08d5f640f13835b9996982067c58 Binary files /dev/null and b/resources/imgs/der.gif differ diff --git a/resources/imgs/dualprompt.png b/resources/imgs/dualprompt.png new file mode 100644 index 0000000000000000000000000000000000000000..0b9e7166f29af8d240bb94030a51677f935216b0 Binary files /dev/null and b/resources/imgs/dualprompt.png differ diff --git a/resources/imgs/flowchart.png b/resources/imgs/flowchart.png new file mode 100644 index 0000000000000000000000000000000000000000..d7a320e8e6f08ee9fe18b93c57dfc0439a7b14ab --- /dev/null +++ b/resources/imgs/flowchart.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e309f46bd8a5faba4acaf2e338d6be53183fc5f144644f5a77558c0e41dd7b0a +size 715821 diff --git a/resources/imgs/l2p.png b/resources/imgs/l2p.png new file mode 100644 index 0000000000000000000000000000000000000000..54fec80e80dec18f7514259f466c66cd99c885e4 Binary files /dev/null and b/resources/imgs/l2p.png differ diff --git a/resources/imgs/lwf.gif b/resources/imgs/lwf.gif new file mode 100644 index 0000000000000000000000000000000000000000..45143d61d13868b81e177c791e02312df853cc65 Binary files /dev/null and b/resources/imgs/lwf.gif differ diff --git a/resources/imgs/moe_adapter4cl.png b/resources/imgs/moe_adapter4cl.png new file mode 100644 index 0000000000000000000000000000000000000000..3d1ab3879b48c2bf713729fed96eadd704184e3e --- /dev/null +++ b/resources/imgs/moe_adapter4cl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:594290f346d6ea884c287d3ed0cda1f817aaa5ba4d944cb7cd101d99e78a42eb +size 179749 diff --git a/resources/imgs/praka.png b/resources/imgs/praka.png new file mode 100644 index 0000000000000000000000000000000000000000..c864124c14167de5ea1c16c3b28e1fa61f73bdd4 --- /dev/null +++ b/resources/imgs/praka.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:82434099dca9ff341628d64d3dcae033b64808db49268497615d81156cd45e7d +size 3068552 diff --git a/resources/imgs/ranpac.png b/resources/imgs/ranpac.png new file mode 100644 index 0000000000000000000000000000000000000000..9ea84b6e258de8b0e6dfcdb86d36e35954eb8965 Binary files /dev/null and b/resources/imgs/ranpac.png differ diff --git a/resources/imgs/rapf.png b/resources/imgs/rapf.png new file mode 100644 index 0000000000000000000000000000000000000000..634e6735dd41d033d908e4729fd98e2393d005aa --- /dev/null +++ b/resources/imgs/rapf.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52ba49c953623aa08c06219f3989fb78eb86d863461ef30c90e5f517e6bcfb98 +size 114375 diff --git a/resources/imgs/sd_lora.png b/resources/imgs/sd_lora.png new file mode 100644 index 0000000000000000000000000000000000000000..fc1bd3bc4ac85317e3dd88c787fe3f2486d0f3dd Binary files /dev/null and b/resources/imgs/sd_lora.png differ diff --git a/resources/imgs/wa.png b/resources/imgs/wa.png new file mode 100644 index 0000000000000000000000000000000000000000..061d921321f42ee5a7473c9f6941bc01a6c072e4 --- /dev/null +++ b/resources/imgs/wa.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42b0e69deb2c019d88a3a96f5241eb62faf8cc831d24d80aa268d64e6fa400d5 +size 225654 diff --git a/run_trainer.py b/run_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..bf27aacb92b5c24f8cf4ae27b373166642d69943 --- /dev/null +++ b/run_trainer.py @@ -0,0 +1,88 @@ +import sys + +sys.dont_write_bytecode = True + +import os +import re +import glob +import time +import torch +import argparse +import subprocess +import torch.multiprocessing as mp + +from core.config import Config +from core import Trainer + +def main(rank, config): + trainer = Trainer(rank, config) + trainer.train_loop() + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, default=None, help='Name of config file') + parser.add_argument('--seed', type=int, default=-1, help='Seed') + parser.add_argument('--device', type=int, default=-1, help='Device') + args = parser.parse_args() + + if args.config: + args.config = args.config + '.yaml' if not args.config.endswith('.yaml') else args.config + config_files = glob.glob(f'./config/**/{args.config}', recursive=True) + assert len(config_files) == 1, "Config files conflict" + config_path = config_files[0] + config = Config(config_path).get_config_dict() + else: + config = Config("./config/InfLoRA.yaml").get_config_dict() + + if config['device_ids'] == 'auto': + least_utilized_device = 0 + lowest_utilization = float('inf') + + try: + result = subprocess.run( + ['nvidia-smi', '--query-gpu=index,memory.used,memory.total,utilization.gpu', '--format=csv,noheader,nounits'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + if result.returncode != 0: + raise RuntimeError(f"nvidia-smi error: {result.stderr}") + + gpu_info = result.stdout.strip().split('\n') + gpu_utilization = [] + + for gpu in gpu_info: + match = re.match(r'(\d+),\s*(\d+),\s*(\d+),\s*(\d+)', gpu) + if match: + device_id, mem_used, mem_total, gpu_util = map(int, match.groups()) + # Combine memory usage and GPU utilization to determine the utilization score + utilization_score = gpu_util + (mem_used / mem_total) * 100 + gpu_utilization.append((device_id, utilization_score)) + + # Sort GPUs by utilization score (ascending) and select the least utilized GPUs + gpu_utilization.sort(key=lambda x: x[1]) + config["device_ids"] = [str(gpu[0]) for gpu in gpu_utilization[:config["n_gpu"]]] + + except Exception as e: + config["device_ids"] = range(config["n_gpu"]) + print(f"Error while querying GPUs: {e}, using default device {config['device_ids']}") + + if args.seed > -1: + print(f'Seed : {config["seed"]} -> {args.seed}') + config['seed'] = args.seed + + if args.device > -1: + config['device_ids'] = args.device + + if not isinstance(config['device_ids'], list): + config['device_ids'] = [config['device_ids']] + + print(f'Selected GPUs: {config["device_ids"]}') + + if config["n_gpu"] > 1: + mp.spawn(main, nprocs=config["n_gpu"], args=(config,)) + pass + os.environ["CUDA_VISIBLE_DEVICES"] = config["device_ids"] + else: + main(0, config)