Feature Extraction
Transformers
Safetensors
English
spectre
medical-imaging
ct-scan
3d
vision-transformer
self-supervised-learning
foundation-model
radiology
custom_code
How to use cclaess/SPECTRE-Large with Transformers:

    # Use a pipeline as a high-level helper
    from transformers import pipeline
    pipe = pipeline("feature-extraction", model="cclaess/SPECTRE-Large", trust_remote_code=True)

    # Load model directly
    from transformers import AutoModel
    model = AutoModel.from_pretrained("cclaess/SPECTRE-Large", trust_remote_code=True, dtype="auto")
Initial commit
- README.md +179 -0
- config.json +32 -0
- configuration_spectre.py +32 -0
- model.safetensors +3 -0
- modeling_spectre.py +38 -0
- spectre/__init__.py +29 -0
- spectre/model.py +234 -0
- spectre/models/__init__.py +53 -0
- spectre/models/eomt.py +230 -0
- spectre/models/layers/__init__.py +11 -0
- spectre/models/layers/attention.py +187 -0
- spectre/models/layers/layernorm.py +23 -0
- spectre/models/layers/patch_embed.py +135 -0
- spectre/models/layers/rotary_pos_embed.py +157 -0
- spectre/models/resnet.py +726 -0
- spectre/models/seomt.py +394 -0
- spectre/models/upsample_anything.py +319 -0
- spectre/models/vision_transformer.py +835 -0
- spectre/models/vision_transformer_features.py +455 -0
- spectre/utils/__init__.py +117 -0
- spectre/utils/_utils.py +49 -0
- spectre/utils/checkpointing.py +238 -0
- spectre/utils/collate.py +120 -0
- spectre/utils/config.py +91 -0
- spectre/utils/dataloader.py +126 -0
- spectre/utils/distributed.py +92 -0
- spectre/utils/lora.py +38 -0
- spectre/utils/masking.py +196 -0
- spectre/utils/modeling.py +550 -0
- spectre/utils/param_groups.py +118 -0
- spectre/utils/scheduler.py +236 -0
README.md
ADDED
@@ -0,0 +1,179 @@
📢 [2026-04-10] SPECTRE is now an official baseline for the [**CVPR 2026 Workshop Competition: Foundation Models for General CT Image Diagnosis**](https://www.codabench.org/competitions/12650/)! See `experiments/cvpr26_fm_for_ct_diag_task_1` for scripts and additional details.

📢 [2026-02-21] SPECTRE has been accepted for presentation at **CVPR 2026** (Denver, Colorado, USA)!

📢 [2026-01-20] [Semantic segmentation](https://github.com/cviviers/nnUNet) code and configurations using the nnUNet framework are now released!


# SPECTRE 👻👻👻

<p align="center">
  <a href="https://pypi.org/project/spectre-fm/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/spectre-fm?style=flat-square&label=version&cacheSeconds=0" /></a>
  <a href="https://pypi.org/project/spectre-fm/"><img alt="Python Versions" src="https://img.shields.io/pypi/pyversions/spectre-fm?style=flat-square&cacheSeconds=0" /></a>
  <a href="https://pypi.org/project/spectre-fm/"><img alt="Downloads per Month" src="https://img.shields.io/pypi/dm/spectre-fm?style=flat-square&label=downloads&cacheSeconds=0" /></a>
  <a href="https://github.com/cclaess/SPECTRE/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/cclaess/SPECTRE?style=flat-square&cacheSeconds=0" /></a>
  <a href="https://huggingface.co/cclaess/SPECTRE"><img alt="Model weights" src="https://img.shields.io/badge/models-Hugging%20Face-yellow?style=flat-square&cacheSeconds=0" /></a>
  <a href="https://arxiv.org/abs/2511.17209"><img alt="Paper" src="https://img.shields.io/badge/paper-arXiv-b31b1b?style=flat-square&cacheSeconds=0" /></a>
</p>

<p align="center">
  <img src="imgs/method_overview.jpg" alt="SPECTRE architecture and pretraining strategies" width="600"/>
</p>

SPECTRE (**S**elf-Supervised & Cross-Modal **P**r**e**training for **CT** **R**epresentation **E**xtraction) is a **Transformer-based foundation model for 3D Computed Tomography (CT) scans**, trained using **self-supervised learning** (SSL) and **cross-modal vision–language alignment** (VLA). It provides rich and generalizable representations from medical imaging data, which can be fine-tuned for downstream tasks such as segmentation, classification, and anomaly detection.

SPECTRE has been trained on a large cohort of **open-source CT scans** of the **human abdomen and thorax**, as well as **paired radiology reports** and **Electronic Health Record data**, enabling it to capture representations that generalize across datasets and clinical settings.

This repository provides pretrained SPECTRE models together with tools for fine-tuning and evaluation.

## 🧠 Pretrained Models
The pretrained SPECTRE model can be imported as follows:

```python
from spectre import SpectreImageFeatureExtractor, MODEL_CONFIGS
import torch

config = MODEL_CONFIGS['spectre-large-pretrained']
model = SpectreImageFeatureExtractor.from_config(config)
model.eval()

# Dummy input scan: (batch, channels, height, width, depth)
# A (384 x 384 x 256) scan is split into a (3 x 3 x 4) grid of (128 x 128 x 64) crops
x = torch.randn(1, 1, 384, 384, 256)
B, C, H, W, D = x.shape

patch_size = (128, 128, 64)
pH, pW, pD = patch_size

# Reshape to the model input layout: (batch, crops, channels, height, width, depth)
x = x.view(
    B, C,
    H // pH, pH,
    W // pW, pW,
    D // pD, pD,
).permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(B, -1, C, pH, pW, pD)

with torch.no_grad():
    features = model(
        x,
        grid_size=(
            H // pH,
            W // pW,
            D // pD,
        ),
    )
print("Features shape:", features.shape)
```
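
The reshape above only works when every spatial dimension is an exact multiple of the crop size. As a minimal sketch of that bookkeeping in plain Python (the helper name `compute_grid_size` is hypothetical and not part of the `spectre` package), the grid passed as `grid_size` and the resulting number of crops can be checked before calling the model:

```python
from math import prod


def compute_grid_size(scan_shape, patch_size):
    """Return the crop grid (gH, gW, gD) for a scan of shape (H, W, D).

    Hypothetical helper for illustration; raises if a dimension is not
    an exact multiple of the corresponding crop size.
    """
    if len(scan_shape) != 3 or len(patch_size) != 3:
        raise ValueError("scan_shape and patch_size must both have 3 entries")
    grid = []
    for size, patch in zip(scan_shape, patch_size):
        if size % patch != 0:
            raise ValueError(f"dimension {size} is not divisible by crop size {patch}")
        grid.append(size // patch)
    return tuple(grid)


grid_size = compute_grid_size((384, 384, 256), (128, 128, 64))
num_crops = prod(grid_size)  # number of crops N fed to the backbone
print(grid_size, num_crops)
```

Scans whose dimensions are not exact multiples would need padding or cropping first; how SPECTRE's own pipelines handle that is not shown in this README.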

Alternatively, you can download the weights of the separate components from Hugging Face using the following links:

| Architecture              | Input Modality    | Pretraining Objective | Model Weights |
|---------------------------|-------------------|-----------------------|---------------|
| SPECTRE-ViT-Local         | CT crops          | SSL                   | [Link](https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128_no_vla.pt?download=true) |
| SPECTRE-ViT-Local         | CT crops          | SSL + VLA             | [Link](https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128.pt?download=true) |
| SPECTRE-ViT-Global        | Embedded CT crops | VLA                   | [Link](https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_combiner_feature_vit_large.pt?download=true) |
| Qwen3-Embedding-0.6B LoRA | Text (radiology)  | VLA                   | [Link](https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_qwen3_embedding_0.6B_lora.pt?download=true) |
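
All four links in the table follow the same `resolve/main/<filename>?download=true` pattern on the `cclaess/SPECTRE` repository. A small sketch that rebuilds them from the filenames (the dictionary keys and the `weight_url` helper are illustrative names, not part of any library):

```python
BASE_URL = "https://huggingface.co/cclaess/SPECTRE/resolve/main"

# Filenames copied from the table above; the short keys are illustrative only.
WEIGHT_FILES = {
    "vit_local_ssl": "spectre_backbone_vit_large_patch16_128_no_vla.pt",
    "vit_local_ssl_vla": "spectre_backbone_vit_large_patch16_128.pt",
    "vit_global_vla": "spectre_combiner_feature_vit_large.pt",
    "text_lora_vla": "spectre_qwen3_embedding_0.6B_lora.pt",
}


def weight_url(key: str) -> str:
    """Build the direct-download URL for one of the released checkpoints."""
    return f"{BASE_URL}/{WEIGHT_FILES[key]}?download=true"


for key in WEIGHT_FILES:
    print(key, "->", weight_url(key))
```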

## 🩻 Segmentation (nnUNet)

If you're looking for an nnUNet-based segmentation pipeline that uses SPECTRE as the backbone, see: https://github.com/cviviers/nnUNet

## 📂 Repository Contents

This repository is organized as follows:

- 🚀 **`src/spectre/`** – Contains the core package, including:
  - Pretraining methods
  - Model architectures
  - Data handling and transformations

- 🛠️ **`src/spectre/configs/`** – Stores configuration files for different training settings.

- 🔬 **`experiments/`** – Includes Python scripts for running various pretraining and downstream experiments.

- 🐳 **`Dockerfile`** – Defines the environment for running a local version of SPECTRE inside a container.

## ⚙️ Setting Up the Environment

To get up and running with SPECTRE, install the base package with pip:

```bash
pip install spectre-fm
```

This installs only the runtime dependencies needed to load and run the pretrained models.

If you want to fine-tune or pretrain SPECTRE, install the `training` extra:

```bash
pip install "spectre-fm[training]"
```

If you only need the evaluation stack, install:

```bash
pip install "spectre-fm[eval]"
```

If you need to train on GDS-enabled systems, install the CUDA 12 specific extra:

```bash
pip install "spectre-fm[gds-cuda12]"  # with training stack: "spectre-fm[training,gds-cuda12]"
```

**Note that** `gds-cuda12` is only compatible with CUDA 12.x environments.

To install everything at once, use:

```bash
pip install "spectre-fm[all]"
```

Alternatively, install the latest updates directly from GitHub:

```bash
pip install git+https://github.com/cclaess/SPECTRE.git
```
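
After installing, it can be useful to confirm which version of the distribution is present. A minimal sketch using only the standard library (the helper name `installed_version` is hypothetical; `spectre-fm` is the distribution name used in the commands above):

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


found = installed_version("spectre-fm")
print("spectre-fm:", found if found is not None else "not installed")
```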

## 🐳 Building and Using Docker

To facilitate deployment and reproducibility, SPECTRE can be run using **Docker**. This lets you set up a fully functional environment from your own local copy of SPECTRE without installing dependencies manually.

### **Building the Docker Image**
First, ensure you have **Docker** installed. Then clone the repository, navigate into it, and build the image:
```bash
git clone https://github.com/cclaess/SPECTRE
cd SPECTRE
docker build -t spectre-fm .
```

### **Running Experiments Inside Docker**
Once the image is built, you can start a container and execute scripts inside it. For example, to run a DINO pretraining experiment:
```bash
docker run --gpus all --rm -v "$(pwd):/mnt" spectre-fm python3 experiments/pretraining/pretrain_dino.py --config_file spectre/configs/dino_default.yaml --output_dir /mnt/outputs/pretraining/dino/
```
- `--gpus all` enables GPU acceleration if available.
- `--rm` removes the container after execution.
- `-v "$(pwd):/mnt"` mounts the current directory at `/mnt` inside the container.

## ⚖️ License
- **Code: MIT.** See `LICENSE` (permissive; commercial use permitted).
- **Pretrained model weights: CC-BY-NC-SA** (non-commercial, share-alike). The weights and any derivative models that include these weights are NOT cleared for commercial use. See `LICENSE_MODELS` for details and the precise license text.

> Note: the pretrained weights are subject to the original dataset licenses. Users intending to use SPECTRE in commercial settings should verify dataset and model licensing and obtain any required permissions.

## 📜 Citation
If you use SPECTRE in your research or wish to cite it, please use the following BibTeX entry for our [preprint](https://arxiv.org/abs/2511.17209):
```
@misc{claessens_scaling_2025,
  title = {Scaling {Self}-{Supervised} and {Cross}-{Modal} {Pretraining} for {Volumetric} {CT} {Transformers}},
  url = {http://arxiv.org/abs/2511.17209},
  doi = {10.48550/arXiv.2511.17209},
  author = {Claessens, Cris and Viviers, Christiaan and D'Amicantonio, Giacomo and Bondarev, Egor and Sommen, Fons van der},
  year = {2025},
}
```

## 🤝 Acknowledgements
This project builds upon prior work in self-supervised learning, medical imaging, and transformer-based representation learning. We especially acknowledge [**MONAI**](https://project-monai.github.io/) for their framework, and the [**timm**](https://timm.fast.ai/) and [**lightly**](https://docs.lightly.ai/self-supervised-learning/) Python libraries for providing 2D PyTorch models (timm) and object-oriented self-supervised learning methods (lightly), parts of which we adapted for 3D.
config.json
ADDED
@@ -0,0 +1,32 @@
{
  "architectures": [
    "SpectreModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_spectre.SpectreConfig",
    "AutoModel": "modeling_spectre.SpectreModel"
  },
  "backbone_kwargs": {
    "global_pool": "",
    "init_values": 1.0,
    "num_classes": 0,
    "pos_embed": "rope",
    "rope_kwargs": {
      "base": 1000.0
    }
  },
  "backbone_name": "vit_large_patch16_128",
  "dtype": "float32",
  "feature_combiner_kwargs": {
    "global_pool": "",
    "init_values": 1.0,
    "num_classes": 0,
    "pos_embed": "rope",
    "rope_kwargs": {
      "base": 100.0
    }
  },
  "feature_combiner_name": "feat_vit_large",
  "model_type": "spectre",
  "transformers_version": "5.3.0"
}
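
The two transformer stages are configured almost identically; in this file they differ only in the RoPE base (1000.0 for the backbone, 100.0 for the feature combiner). A short sketch that makes the comparison explicit, using an inline excerpt of the JSON above instead of reading the file from disk (the variable names are illustrative):

```python
import json

# Excerpt of config.json, reduced to the two keyword-argument blocks.
CONFIG_EXCERPT = """
{
  "backbone_kwargs": {"global_pool": "", "init_values": 1.0, "num_classes": 0,
                      "pos_embed": "rope", "rope_kwargs": {"base": 1000.0}},
  "feature_combiner_kwargs": {"global_pool": "", "init_values": 1.0, "num_classes": 0,
                              "pos_embed": "rope", "rope_kwargs": {"base": 100.0}}
}
"""

config = json.loads(CONFIG_EXCERPT)
backbone = config["backbone_kwargs"]
combiner = config["feature_combiner_kwargs"]

# Collect the top-level keys whose values differ between the two stages.
differing = sorted(k for k in backbone if backbone[k] != combiner[k])
print(differing)
print(backbone["rope_kwargs"]["base"], combiner["rope_kwargs"]["base"])
```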
configuration_spectre.py
ADDED
@@ -0,0 +1,32 @@
from transformers import PretrainedConfig


class SpectreConfig(PretrainedConfig):
    model_type = "spectre"

    def __init__(
        self,
        backbone_name="vit_large_patch16_128",
        backbone_kwargs={
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        feature_combiner_name="feat_vit_large",
        feature_combiner_kwargs={
            "num_classes": 0,
            "global_pool": "",
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.backbone_name = backbone_name
        self.backbone_kwargs = backbone_kwargs or {}
        self.feature_combiner_name = feature_combiner_name
        self.feature_combiner_kwargs = feature_combiner_kwargs or {}
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c90810e90d833cfff54880997fad4d964963fef4dd6824789f44a5bed063cb61
size 1587720248
modeling_spectre.py
ADDED
@@ -0,0 +1,38 @@
import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from spectre.model import SpectreImageFeatureExtractor
from configuration_spectre import SpectreConfig


class SpectreModel(PreTrainedModel):
    config_class = SpectreConfig
    base_model_prefix = "spectre"
    main_input_name = "pixel_values"

    def __init__(self, config):
        super().__init__(config)

        self.model = SpectreImageFeatureExtractor(
            backbone_name=config.backbone_name,
            backbone_kwargs=config.backbone_kwargs,
            feature_combiner_name=config.feature_combiner_name,
            feature_combiner_kwargs=config.feature_combiner_kwargs,
        )

        self.post_init()

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_size=None,
        return_dict=True,
        **kwargs,
    ):
        outputs = self.model(pixel_values, grid_size=grid_size)

        if not return_dict:
            return (outputs,)

        return BaseModelOutput(last_hidden_state=outputs)
spectre/__init__.py
ADDED
@@ -0,0 +1,29 @@
"""Top-level package for spectre.

Expose a small, stable public API here so users can do:

    from spectre import SpectreImageFeatureExtractor, models

Keep implementations in subpackages; this file only re-exports the most
important symbols and subpackages for convenience.
"""
from .model import SpectreImageFeatureExtractor, MODEL_CONFIGS
from . import models
from . import utils

__version__ = "0.1.0"
__author__ = "Cris Claessens"
__email__ = "c.h.b.claessens@tue.nl"

# Only names that are actually imported above are exported; listing
# subpackages that are not imported here would break `from spectre import *`.
__all__ = [
    "SpectreImageFeatureExtractor",
    "MODEL_CONFIGS",
    "models",
    "utils",
    "__version__",
    "__author__",
    "__email__",
]
spectre/model.py
ADDED
@@ -0,0 +1,234 @@
import torch
import torch.nn as nn


MODEL_CONFIGS = {
    "spectre-small": {
        "name": "spectre-small",
        "backbone": "vit_small_patch16_128",
        "backbone_checkpoint_path_or_url": None,
        "backbone_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        "feature_combiner": "feat_vit_small",
        "feature_combiner_checkpoint_path_or_url": None,
        "feature_combiner_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        "description": "SPECTRE model with ViT-Small backbone and feature combiner.",
    },  # Pretrained/Distilled checkpoints will be added later
    "spectre-base": {
        "name": "spectre-base",
        "backbone": "vit_base_patch16_128",
        "backbone_checkpoint_path_or_url": None,
        "backbone_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        "feature_combiner": "feat_vit_base",
        "feature_combiner_checkpoint_path_or_url": None,
        "feature_combiner_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        "description": "SPECTRE model with ViT-Base backbone and feature combiner.",
    },  # Pretrained/Distilled checkpoints will be added later
    "spectre-large": {
        "name": "spectre-large",
        "backbone": "vit_large_patch16_128",
        "backbone_checkpoint_path_or_url": None,
        "backbone_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        "feature_combiner": "feat_vit_large",
        "feature_combiner_checkpoint_path_or_url": None,
        "feature_combiner_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        "description": "SPECTRE model with ViT-Large backbone and feature combiner.",
    },
    "spectre-large-pretrained": {
        "name": "spectre-large-pretrained",
        "backbone": "vit_large_patch16_128",
        "backbone_checkpoint_path_or_url": "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128.pt?download=true",
        "backbone_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 1000.0},
            "init_values": 1.0,
        },
        "feature_combiner": "feat_vit_large",
        "feature_combiner_checkpoint_path_or_url": "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_combiner_feature_vit_large.pt?download=true",
        "feature_combiner_kwargs": {
            "num_classes": 0,
            "global_pool": '',
            "pos_embed": "rope",
            "rope_kwargs": {"base": 100.0},
            "init_values": 1.0,
        },
        "description": "Pretrained SPECTRE model with ViT-Large backbone and feature combiner.",
    }
}


class SpectreImageFeatureExtractor(nn.Module):
    def __init__(
        self,
        backbone_name: str,
        backbone_kwargs: dict = {},
        backbone_checkpoint_path_or_url: str | None = None,
        feature_combiner_name: str | None = None,
        feature_combiner_kwargs: dict = {},
        feature_combiner_checkpoint_path_or_url: str | None = None,
        **kwargs,
    ):
        super().__init__()
        self.backbone = None
        self.feature_combiner = None
        self._init_backbone(
            backbone_name,
            checkpoint_path_or_url=backbone_checkpoint_path_or_url,
            **backbone_kwargs,
            **kwargs,
        )
        if feature_combiner_name is not None:
            self._init_feature_combiner(
                feature_combiner_name,
                checkpoint_path_or_url=feature_combiner_checkpoint_path_or_url,
                **feature_combiner_kwargs,
                **kwargs,
            )

    def _init_backbone(
        self,
        model_name: str,
        checkpoint_path_or_url: str | None = None,
        **kwargs
    ):
        backbone_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name)
        self.backbone = backbone_cls(
            checkpoint_path_or_url=checkpoint_path_or_url,
            **kwargs,
        )

    def _init_feature_combiner(
        self,
        model_name: str,
        checkpoint_path_or_url: str | None = None,
        **kwargs,
    ):
        if self.backbone.global_pool == '':
            patch_dim = self.backbone.embed_dim * 2  # CLS + AVG pooled tokens
        else:
            patch_dim = self.backbone.embed_dim

        feature_combiner_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name)
        self.feature_combiner = feature_combiner_cls(
            patch_dim=patch_dim,
            checkpoint_path_or_url=checkpoint_path_or_url,
            **kwargs,
        )

    def extract_backbone_features(
        self,
        x: torch.Tensor,
    ):
        """
        Extract features from the backbone for a batch of image sets. Input is expected to be of
        shape (B, N, C, H, W, D), where B is the batch size, N is the number of image patches per
        image, C is the number of channels, H is height, W is width, and D is depth.
        The output will be a tensor of extracted features (B, N, T, F) where T is the number of
        tokens and F is the feature dimension.

        Args:
            x (torch.Tensor): Input tensor of shape (B, N, C, H, W, D)
        Returns:
            torch.Tensor: Extracted features of shape (B, N, T, F)
        """
        assert x.ndim == 6, "Input tensor must have 6 dimensions: (B, N, C, H, W, D)"
        B, N, C, H, W, D = x.shape
        x = x.view(B * N, C, H, W, D)
        features = self.backbone(x)
        if features.ndim == 2:  # only CLS token
            features = features.unsqueeze(1)
        features = features.view(B, N, features.shape[1], -1)
        return features

    def combine_features(
        self,
        features: torch.Tensor,
        grid_size: tuple[int, int, int],
    ):
        """
        Combine features from multiple image patches using the feature combiner.

        Args:
            features (torch.Tensor): Input features of shape (B, N, T, F)
            grid_size (tuple[int, int, int]): Grid size of the image patches
        Returns:
            torch.Tensor: Combined features of shape (B, T', F')
        """
        # Check the rank before unpacking so a wrong rank raises the intended message.
        assert features.ndim == 4, "Input features must have 4 dimensions: (B, N, T, F)"
        _, N, T, _ = features.shape
        assert N == grid_size[0] * grid_size[1] * grid_size[2], \
            "Number of patches N must match the product of grid_size dimensions"

        if T == 1:  # only CLS token
            features = features.squeeze(2)
        else:
            # We combine CLS tokens with AVG pooling of other tokens
            features = torch.cat([
                features[:, :, 0, :],  # CLS token (B, N, F)
                features[:, :, 1:, :].mean(dim=2)  # AVG pooled tokens (B, N, F)
            ], dim=-1)  # (B, N, 2F)
        features = self.feature_combiner(features, grid_size)  # (B, T', F')
        return features

    def forward(self, x, grid_size: tuple[int, int, int] | None = None):
        features = self.extract_backbone_features(x)
        if self.feature_combiner is not None:
            assert grid_size is not None, \
                "`grid_size` must be provided when using feature combiner"
            features = self.combine_features(features, grid_size)
        return features

    @classmethod
    def from_config(
        cls,
        config: dict,
        **kwargs,
    ) -> 'SpectreImageFeatureExtractor':

        model = cls(
            backbone_name=config["backbone"],
            backbone_checkpoint_path_or_url=config.get("backbone_checkpoint_path_or_url", None),
            backbone_kwargs=config.get("backbone_kwargs", {}),
            feature_combiner_name=config.get("feature_combiner", None),
            feature_combiner_checkpoint_path_or_url=config.get("feature_combiner_checkpoint_path_or_url", None),
            feature_combiner_kwargs=config.get("feature_combiner_kwargs", {}),
            **kwargs,
        )
        return model
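
For reference, the pooling step in `combine_features` reduces the tokens of each crop to a single vector by concatenating the CLS token with the mean of the remaining tokens, which is why the combiner is built with `patch_dim = embed_dim * 2` when `global_pool` is empty. A dependency-free sketch of that reduction for a single crop, using nested lists in place of tensors (the function name `cls_plus_mean` is hypothetical):

```python
def cls_plus_mean(tokens):
    """Concatenate the CLS token with the mean of the remaining tokens.

    `tokens` is a list of T feature vectors (each of length F); the first
    entry plays the role of the CLS token. Returns a vector of length 2F.
    """
    if len(tokens) < 2:
        raise ValueError("need a CLS token and at least one patch token")
    cls_token, patch_tokens = tokens[0], tokens[1:]
    mean_token = [sum(column) / len(patch_tokens) for column in zip(*patch_tokens)]
    return list(cls_token) + mean_token


# One crop with T = 3 tokens of dimension F = 2.
crop_tokens = [[1.0, 2.0], [0.0, 4.0], [2.0, 0.0]]
pooled = cls_plus_mean(crop_tokens)
print(pooled)
```

The real implementation applies the same reduction to every crop at once with `torch.cat` and `mean(dim=2)`.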
spectre/models/__init__.py
ADDED
@@ -0,0 +1,53 @@
from .vision_transformer import (
    VisionTransformer,
    vit_tiny_patch16_128,
    vit_small_patch16_128,
    vit_base_patch16_128,
    vit_base_patch32_128,
    vit_large_patch16_128,
    vit_large_patch32_128,
)
from .vision_transformer_features import (
    FeatureVisionTransformer,
    feat_vit_tiny,
    feat_vit_small,
    feat_vit_base,
    feat_vit_large,
)
from .resnet import (
    ResNet,
    resnet18,
    resnet34,
    resnet50,
    resnet101,
    resnext50,
    resnext101,
)
from .eomt import EoMT
from .seomt import SEoMT
from .upsample_anything import UPA

__all__ = [
    'VisionTransformer',
    'vit_tiny_patch16_128',
    'vit_small_patch16_128',
    'vit_base_patch16_128',
    'vit_base_patch32_128',
    'vit_large_patch16_128',
    'vit_large_patch32_128',
    'FeatureVisionTransformer',
    'feat_vit_tiny',
    'feat_vit_small',
    'feat_vit_base',
    'feat_vit_large',
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnext50',
    'resnext101',
    'EoMT',
    'UPA',
    'SEoMT',
]
|
spectre/models/eomt.py
ADDED
|
@@ -0,0 +1,230 @@
|
| 1 |
+
# Adapted from https://github.com/tue-mps/eomt/
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from typing import Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from spectre.models.layers import LayerNorm3d
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ScaleBlock(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
embed_dim: int,
|
| 18 |
+
scale_factors: Union[int, Tuple[int, int, int]] = (2, 2, 2),
|
| 19 |
+
conv1_layer: nn.Module = nn.ConvTranspose3d,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
self.conv1 = conv1_layer(
|
| 24 |
+
embed_dim,
|
| 25 |
+
embed_dim,
|
| 26 |
+
kernel_size=scale_factors,
|
| 27 |
+
stride=scale_factors,
|
| 28 |
+
)
|
| 29 |
+
self.act = nn.GELU()
|
| 30 |
+
self.conv2 = nn.Conv3d(
|
| 31 |
+
embed_dim,
|
| 32 |
+
embed_dim,
|
| 33 |
+
kernel_size=3,
|
| 34 |
+
padding=1,
|
| 35 |
+
groups=embed_dim,
|
| 36 |
+
bias=False,
|
| 37 |
+
)
|
| 38 |
+
self.norm = LayerNorm3d(embed_dim)
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
x = self.conv1(x)
|
| 42 |
+
x = self.act(x)
|
| 43 |
+
x = self.conv2(x)
|
| 44 |
+
x = self.norm(x)
|
| 45 |
+
|
| 46 |
+
return x
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def compute_upscale_stages(patch_size, min_size=4):
|
| 50 |
+
# Compute how many times to upscale per dimension
|
| 51 |
+
num_stages = []
|
| 52 |
+
for size in patch_size:
|
| 53 |
+
stages = max(0, int(math.log2(size)) - int(math.log2(min_size)))
|
| 54 |
+
num_stages.append(stages)
|
| 55 |
+
return num_stages
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class EoMT(nn.Module):
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
backbone: "VisionTransformer",
|
| 62 |
+
num_classes: int,
|
| 63 |
+
num_q: int,
|
| 64 |
+
num_blocks=4,
|
| 65 |
+
masked_attn_enabled=True,
|
| 66 |
+
):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.backbone = backbone
|
| 69 |
+
self.num_q = num_q
|
| 70 |
+
self.num_blocks = num_blocks
|
| 71 |
+
self.masked_attn_enabled = masked_attn_enabled
|
| 72 |
+
|
| 73 |
+
self.register_buffer("attn_mask_probs", torch.ones(num_blocks))
|
| 74 |
+
|
| 75 |
+
self.q = nn.Embedding(num_q, self.backbone.embed_dim)
|
| 76 |
+
|
| 77 |
+
self.class_head = nn.Linear(self.backbone.embed_dim, num_classes + 1)
|
| 78 |
+
|
| 79 |
+
self.mask_head = nn.Sequential(
|
| 80 |
+
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 81 |
+
nn.GELU(),
|
| 82 |
+
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 83 |
+
nn.GELU(),
|
| 84 |
+
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
patch_size = self.backbone.patch_embed.patch_size
|
| 88 |
+
num_upscale_stages = compute_upscale_stages(patch_size, min_size=4)
|
| 89 |
+
|
| 90 |
+
# Build per-stage scale factors list
|
| 91 |
+
max_stages = max(num_upscale_stages)
|
| 92 |
+
upscale_blocks = []
|
| 93 |
+
for stage_idx in range(max_stages):
|
| 94 |
+
# for each dimension, upscale by 2 only if this dimension still has
|
| 95 |
+
# remaining upscales at this stage
|
| 96 |
+
scale_factors = tuple(
|
| 97 |
+
2 if stage_idx < num_upscale_stages[dim] else 1
|
| 98 |
+
for dim in range(len(patch_size))
|
| 99 |
+
)
|
| 100 |
+
upscale_blocks.append(ScaleBlock(self.backbone.embed_dim,
|
| 101 |
+
scale_factors=scale_factors))
|
| 102 |
+
|
| 103 |
+
self.upscale = nn.Sequential(*upscale_blocks)
|
| 104 |
+
|
| 105 |
+
def _predict(self, x: torch.Tensor):
|
| 106 |
+
q = x[:, : self.num_q, :]
|
| 107 |
+
|
| 108 |
+
class_logits = self.class_head(q)
|
| 109 |
+
|
| 110 |
+
x = x[:, self.num_q + self.backbone.num_prefix_tokens :, :]
|
| 111 |
+
x = x.transpose(1, 2).reshape(
|
| 112 |
+
x.shape[0], -1, *self.backbone.patch_embed.grid_size
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
mask_logits = torch.einsum(
|
| 116 |
+
"bqc, bchwd -> bqhwd", self.mask_head(q), self.upscale(x)
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
return mask_logits, class_logits
|
| 120 |
+
|
| 121 |
+
@torch.compiler.disable
|
| 122 |
+
def _disable_attn_mask(self, attn_mask, prob):
|
| 123 |
+
if prob < 1:
|
| 124 |
+
random_queries = (
|
| 125 |
+
torch.rand(attn_mask.shape[0], self.num_q, device=attn_mask.device)
|
| 126 |
+
> prob
|
| 127 |
+
)
|
| 128 |
+
attn_mask[
|
| 129 |
+
:, : self.num_q, self.num_q + self.backbone.num_prefix_tokens :
|
| 130 |
+
][random_queries] = True
|
| 131 |
+
|
| 132 |
+
return attn_mask
|
| 133 |
+
|
| 134 |
+
def _attn(self, module: 'Attention', x: torch.Tensor, mask: Optional[torch.Tensor], rope=None):
|
| 135 |
+
B, N, C = x.shape
|
| 136 |
+
|
| 137 |
+
q = module.q(x).reshape(B, N, module.num_heads, module.head_dim).permute(0, 2, 1, 3)
|
| 138 |
+
kv = module.kv(x).reshape(B, N, 2, module.num_heads, module.head_dim)
|
| 139 |
+
k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
|
| 140 |
+
q, k = module.q_norm(q), module.k_norm(k)
|
| 141 |
+
|
| 142 |
+
if mask is not None:
|
| 143 |
+
mask = mask[:, None, ...].expand(-1, module.num_heads, -1, -1)
|
| 144 |
+
|
| 145 |
+
dropout_p = module.attn_drop.p if self.training else 0.0
|
| 146 |
+
|
| 147 |
+
if rope is not None:
|
| 148 |
+
if isinstance(rope, list):
|
| 149 |
+
rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
|
| 150 |
+
q, k = module.apply_rotary_pos_emb(q, k, rope)
|
| 151 |
+
|
| 152 |
+
if module.fused_attn:
|
| 153 |
+
x = F.scaled_dot_product_attention(q, k, v, mask, dropout_p)
|
| 154 |
+
else:
|
| 155 |
+
attn = (q @ k.transpose(-2, -1)) * module.scale
|
| 156 |
+
if mask is not None:
|
| 157 |
+
attn = attn.masked_fill(~mask, float("-inf"))
|
| 158 |
+
attn = F.softmax(attn, dim=-1)
|
| 159 |
+
attn = module.attn_drop(attn)
|
| 160 |
+
x = attn @ v
|
| 161 |
+
|
| 162 |
+
x = module.proj_drop(module.proj(x.transpose(1, 2).reshape(B, N, C)))
|
| 163 |
+
|
| 164 |
+
return x
|
| 165 |
+
|
| 166 |
+
def forward(self, x: torch.Tensor):
|
| 167 |
+
x = self.backbone.patch_embed(x)
|
| 168 |
+
x, rope = self.backbone._pos_embed(x)
|
| 169 |
+
x = self.backbone.patch_drop(x)
|
| 170 |
+
x = self.backbone.norm_pre(x)
|
| 171 |
+
|
| 172 |
+
attn_mask = None
|
| 173 |
+
mask_logits_per_layer, class_logits_per_layer = [], []
|
| 174 |
+
|
| 175 |
+
for i, block in enumerate(self.backbone.blocks):
|
| 176 |
+
if i == len(self.backbone.blocks) - self.num_blocks:
|
| 177 |
+
x = torch.cat(
|
| 178 |
+
(self.q.weight[None, :, :].expand(x.shape[0], -1, -1), x), dim=1
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if (
|
| 182 |
+
self.masked_attn_enabled
|
| 183 |
+
and i >= len(self.backbone.blocks) - self.num_blocks
|
| 184 |
+
):
|
| 185 |
+
mask_logits, class_logits = self._predict(self.backbone.norm(x))
|
| 186 |
+
mask_logits_per_layer.append(mask_logits)
|
| 187 |
+
class_logits_per_layer.append(class_logits)
|
| 188 |
+
|
| 189 |
+
attn_mask = torch.ones(
|
| 190 |
+
x.shape[0],
|
| 191 |
+
x.shape[1],
|
| 192 |
+
x.shape[1],
|
| 193 |
+
dtype=torch.bool,
|
| 194 |
+
device=x.device,
|
| 195 |
+
)
|
| 196 |
+
interpolated = F.interpolate(
|
| 197 |
+
mask_logits,
|
| 198 |
+
self.backbone.patch_embed.grid_size,
|
| 199 |
+
mode="trilinear",
|
| 200 |
+
)
|
| 201 |
+
interpolated = interpolated.view(
|
| 202 |
+
interpolated.size(0), interpolated.size(1), -1
|
| 203 |
+
)
|
| 204 |
+
attn_mask[
|
| 205 |
+
:,
|
| 206 |
+
: self.num_q,
|
| 207 |
+
self.num_q + self.backbone.num_prefix_tokens :,
|
| 208 |
+
] = (
|
| 209 |
+
interpolated > 0
|
| 210 |
+
)
|
| 211 |
+
attn_mask = self._disable_attn_mask(
|
| 212 |
+
attn_mask,
|
| 213 |
+
self.attn_mask_probs[
|
| 214 |
+
i - len(self.backbone.blocks) + self.num_blocks
|
| 215 |
+
],
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
x = x + block.drop_path1(
|
| 219 |
+
block.ls1(self._attn(block.attn, block.norm1(x), attn_mask, rope))
|
| 220 |
+
)
|
| 221 |
+
x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))
|
| 222 |
+
|
| 223 |
+
mask_logits, class_logits = self._predict(self.backbone.norm(x))
|
| 224 |
+
mask_logits_per_layer.append(mask_logits)
|
| 225 |
+
class_logits_per_layer.append(class_logits)
|
| 226 |
+
|
| 227 |
+
return (
|
| 228 |
+
mask_logits_per_layer,
|
| 229 |
+
class_logits_per_layer,
|
| 230 |
+
)
|
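The per-stage scale factors that `EoMT.__init__` passes to `ScaleBlock` can be traced by hand. The sketch below repeats the same arithmetic in plain Python, without PyTorch. The helper `stage_scale_factors` and the patch size `(16, 16, 8)` are illustrative assumptions (the patch size is taken from the `PatchEmbed` default, not from any particular checkpoint config).

```python
import math


def compute_upscale_stages(patch_size, min_size=4):
    # Same arithmetic as in eomt.py: number of x2 upscales needed per axis.
    return [max(0, int(math.log2(s)) - int(math.log2(min_size))) for s in patch_size]


def stage_scale_factors(patch_size, min_size=4):
    # Hypothetical helper: one (h, w, d) factor tuple per ScaleBlock,
    # mirroring the loop in EoMT.__init__.
    stages = compute_upscale_stages(patch_size, min_size)
    return [
        tuple(2 if i < stages[dim] else 1 for dim in range(len(patch_size)))
        for i in range(max(stages))
    ]


patch_size = (16, 16, 8)  # assumed for illustration
factors = stage_scale_factors(patch_size)

# Total upscaling per axis, and the remaining stride of the mask logits
# relative to the input volume.
total = tuple(math.prod(f[d] for f in factors) for d in range(3))
stride = tuple(p // t for p, t in zip(patch_size, total))

print(factors)  # per-stage factors
print(total)    # cumulative upscaling per axis
print(stride)   # voxels per mask-logit cell
```

Under these assumptions the anisotropic patch is upscaled twice in-plane but only once along depth, so the mask logits end up at the same stride (`min_size`) along every axis.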
spectre/models/layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
| 1 |
+
from .patch_embed import PatchEmbed
|
| 2 |
+
from .attention import Attention
|
| 3 |
+
from .layernorm import LayerNorm3d
|
| 4 |
+
from .rotary_pos_embed import RotaryPositionEmbedding
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
'PatchEmbed',
|
| 8 |
+
'Attention',
|
| 9 |
+
'LayerNorm3d',
|
| 10 |
+
'RotaryPositionEmbedding',
|
| 11 |
+
]
|
spectre/models/layers/attention.py
ADDED
|
@@ -0,0 +1,187 @@
|
| 1 |
+
from typing import Type, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.jit import Final
|
| 7 |
+
from timm.layers import use_fused_attn
|
| 8 |
+
|
| 9 |
+
from spectre.models.layers.rotary_pos_embed import rope_apply
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Attention(nn.Module):
|
| 13 |
+
fused_attn: Final[bool]
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
dim: int,
|
| 18 |
+
num_heads: int = 8,
|
| 19 |
+
mode: str = "mha",
|
| 20 |
+
q_proj_dim: Optional[int] = None,
|
| 21 |
+
kv_proj_dim: Optional[int] = None,
|
| 22 |
+
qkv_bias: bool = False,
|
| 23 |
+
qk_norm: bool = False,
|
| 24 |
+
proj_bias: bool = True,
|
| 25 |
+
attn_drop: float = 0.,
|
| 26 |
+
proj_drop: float = 0.,
|
| 27 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
| 28 |
+
) -> None:
|
| 29 |
+
super().__init__()
|
| 30 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
| 31 |
+
self.num_heads = num_heads
|
| 32 |
+
self.head_dim = dim // num_heads
|
| 33 |
+
self.scale = self.head_dim ** -0.5
|
| 34 |
+
self.fused_attn = use_fused_attn()
|
| 35 |
+
self.mode = mode.lower()
|
| 36 |
+
assert self.mode in ["mha", "mqa", "mla"], "Attention mode must be 'mha', 'mqa', or 'mla'"
|
| 37 |
+
assert not (self.mode == "mla" and kv_proj_dim is None), "kv_proj_dim must be provided for 'mla' mode"
|
| 38 |
+
assert not (self.mode == "mla" and q_proj_dim is None), "q_proj_dim must be provided for 'mla' mode"
|
| 39 |
+
|
| 40 |
+
if self.mode == "mha":
|
| 41 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
| 42 |
+
self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias) # Key and value pair for every head
|
| 43 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 44 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 45 |
+
elif self.mode == "mqa":
|
| 46 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
| 47 |
+
self.kv = nn.Linear(dim, 2 * self.head_dim, bias=qkv_bias) # Key and value pair shared across heads
|
| 48 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 49 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 50 |
+
elif self.mode == "mla":
|
| 51 |
+
self.q_proj = nn.Linear(dim, q_proj_dim, bias=qkv_bias) # Projected query for every head
|
| 52 |
+
self.kv_proj = nn.Linear(dim, kv_proj_dim, bias=qkv_bias) # Projected key and value pair for every head
|
| 53 |
+
self.q_norm = norm_layer(q_proj_dim) if qk_norm else nn.Identity()
|
| 54 |
+
self.kv_norm = norm_layer(kv_proj_dim) if qk_norm else nn.Identity()
|
| 55 |
+
self.q = nn.Linear(q_proj_dim, dim, bias=qkv_bias) # Query for every head
|
| 56 |
+
self.kv = nn.Linear(kv_proj_dim, 2 * dim, bias=qkv_bias) # Key and value pair for every head
|
| 57 |
+
|
| 58 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 59 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 60 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 61 |
+
|
| 62 |
+
def apply_rotary_pos_emb(
|
| 63 |
+
self,
|
| 64 |
+
q: torch.Tensor,
|
| 65 |
+
k: torch.Tensor,
|
| 66 |
+
rope: Tuple[torch.Tensor, torch.Tensor],
|
| 67 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 68 |
+
"""
|
| 69 |
+
Apply RoPE to the query and key tensors.
|
| 70 |
+
Args:
|
| 71 |
+
q (torch.Tensor): Query tensor of shape (B, num_heads, N, head_dim)
|
| 72 |
+
k (torch.Tensor): Key tensor of shape (B, num_heads, N, head_dim)
|
| 73 |
+
rope (Tuple[torch.Tensor, torch.Tensor]): Tuple of (sin, cos) tensors for RoPE application.
|
| 74 |
+
Sin and cos can be of shape [N, head_dim] or [B, N, head_dim].
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
# Match dtype to rope for numeric stability
|
| 78 |
+
q_dtype, k_dtype = q.dtype, k.dtype
|
| 79 |
+
sin, cos = rope
|
| 80 |
+
|
| 81 |
+
if sin.ndim == 2:
|
| 82 |
+
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 83 |
+
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 84 |
+
elif sin.ndim == 3:
|
| 85 |
+
sin = sin.unsqueeze(1)
|
| 86 |
+
cos = cos.unsqueeze(1)
|
| 87 |
+
else:
|
| 88 |
+
raise ValueError("RoPE sin/cos must be of shape [N, head_dim] or [B, N, head_dim]")
|
| 89 |
+
|
| 90 |
+
rope_dtype = sin.dtype
|
| 91 |
+
|
| 92 |
+
q = q.to(dtype=rope_dtype)
|
| 93 |
+
k = k.to(dtype=rope_dtype)
|
| 94 |
+
|
| 95 |
+
N = q.shape[-2] # total tokens per sample
|
| 96 |
+
N_spatial = sin.shape[-2] # number of spatial tokens covered by rope
|
| 97 |
+
prefix = N - N_spatial # e.g., [cls] or [reg] tokens at the front
|
| 98 |
+
assert prefix >= 0, "RoPE sin/cos length exceeds sequence length"
|
| 99 |
+
|
| 100 |
+
if prefix > 0:
|
| 101 |
+
q_prefix = q[:, :, :prefix, :]
|
| 102 |
+
k_prefix = k[:, :, :prefix, :]
|
| 103 |
+
q_spatial = q[:, :, prefix:, :]
|
| 104 |
+
k_spatial = k[:, :, prefix:, :]
|
| 105 |
+
else:
|
| 106 |
+
q_prefix = k_prefix = None
|
| 107 |
+
q_spatial, k_spatial = q, k
|
| 108 |
+
|
| 109 |
+
# Apply RoPE on the spatial tail
|
| 110 |
+
q_spatial = rope_apply(q_spatial, sin, cos)
|
| 111 |
+
k_spatial = rope_apply(k_spatial, sin, cos)
|
| 112 |
+
|
| 113 |
+
# Stitch back
|
| 114 |
+
if prefix > 0:
|
| 115 |
+
q = torch.cat((q_prefix, q_spatial), dim=-2)
|
| 116 |
+
k = torch.cat((k_prefix, k_spatial), dim=-2)
|
| 117 |
+
else:
|
| 118 |
+
q, k = q_spatial, k_spatial
|
| 119 |
+
|
| 120 |
+
# Cast back to original dtypes
|
| 121 |
+
q = q.to(dtype=q_dtype)
|
| 122 |
+
k = k.to(dtype=k_dtype)
|
| 123 |
+
|
| 124 |
+
return q, k
|
| 125 |
+
|
| 126 |
+
def compute_qkv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 127 |
+
B, N, _ = x.shape
|
| 128 |
+
if self.mode == "mha":
|
| 129 |
+
q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 130 |
+
kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 131 |
+
k, v = kv.unbind(0)
|
| 132 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 133 |
+
elif self.mode == "mqa":
|
| 134 |
+
q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 135 |
+
kv = self.kv(x).reshape(B, N, 2, 1, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 136 |
+
kv = kv.expand(-1, -1, self.num_heads, -1, -1) # Expand to match num_heads
|
| 137 |
+
k, v = kv.unbind(0)
|
| 138 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 139 |
+
elif self.mode == "mla":
|
| 140 |
+
q = self.q_proj(x)
|
| 141 |
+
kv = self.kv_proj(x)
|
| 142 |
+
q, kv = self.q_norm(q), self.kv_norm(kv) # Normalization on projections
|
| 143 |
+
q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 144 |
+
kv = self.kv(kv).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 145 |
+
k, v = kv.unbind(0)
|
| 146 |
+
return q, k, v
|
| 147 |
+
|
| 148 |
+
def compute_attention(
|
| 149 |
+
self,
|
| 150 |
+
q: torch.Tensor,
|
| 151 |
+
k: torch.Tensor,
|
| 152 |
+
v: torch.Tensor,
|
| 153 |
+
) -> torch.Tensor:
|
| 154 |
+
B, _, N, _ = q.shape
|
| 155 |
+
C = self.num_heads * self.head_dim
|
| 156 |
+
|
| 157 |
+
if self.fused_attn:
|
| 158 |
+
x = F.scaled_dot_product_attention(
|
| 159 |
+
q, k, v,
|
| 160 |
+
dropout_p=self.attn_drop.p if self.training else 0.,
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
q = q * self.scale
|
| 164 |
+
attn = q @ k.transpose(-2, -1)
|
| 165 |
+
attn = attn.softmax(dim=-1)
|
| 166 |
+
attn = self.attn_drop(attn)
|
| 167 |
+
x = attn @ v
|
| 168 |
+
|
| 169 |
+
return x.transpose(1, 2).reshape(B, N, C)
|
| 170 |
+
|
| 171 |
+
def forward(
|
| 172 |
+
self,
|
| 173 |
+
x: torch.Tensor,
|
| 174 |
+
rope = None,
|
| 175 |
+
) -> torch.Tensor:
|
| 176 |
+
|
| 177 |
+
q, k, v = self.compute_qkv(x)
|
| 178 |
+
|
| 179 |
+
if rope is not None:
|
| 180 |
+
if isinstance(rope, list):
|
| 181 |
+
rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
|
| 182 |
+
q, k = self.apply_rotary_pos_emb(q, k, rope)
|
| 183 |
+
|
| 184 |
+
x = self.compute_attention(q, k, v)
|
| 185 |
+
x = self.proj(x)
|
| 186 |
+
x = self.proj_drop(x)
|
| 187 |
+
return x
|
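The prefix handling in `apply_rotary_pos_emb` can be illustrated without PyTorch. The following is a minimal sketch on plain Python lists: it applies the rotate-half form of RoPE only to the trailing spatial tokens and leaves leading prefix tokens (for example a class token) untouched. The helper `apply_to_sequence`, the toy head dimension of 4, and the angles are assumptions made for illustration; the real module works on batched tensors and requires `head_dim % 6 == 0`.

```python
import math


def rotate_half(v):
    # [x1, x2] -> [-x2, x1], as in rope_rotate_half
    half = len(v) // 2
    return [-a for a in v[half:]] + v[:half]


def rope_apply(v, sin, cos):
    # Elementwise x * cos + rotate_half(x) * sin
    r = rotate_half(v)
    return [x * c + y * s for x, y, s, c in zip(v, r, sin, cos)]


def apply_to_sequence(tokens, sin, cos):
    # Hypothetical helper. tokens: N vectors; sin/cos: N_spatial vectors.
    # Any tokens beyond the rope length are treated as a prefix and copied as-is.
    prefix = len(tokens) - len(sin)
    assert prefix >= 0, "RoPE sin/cos length exceeds sequence length"
    rotated = [rope_apply(t, s, c) for t, s, c in zip(tokens[prefix:], sin, cos)]
    return tokens[:prefix] + rotated


def norm(v):
    return math.sqrt(sum(x * x for x in v))


tokens = [
    [1.0, 0.0, 0.0, 1.0],   # prefix token (e.g. class token)
    [1.0, 2.0, 3.0, 4.0],   # spatial token 0
    [0.5, -1.0, 2.0, 0.0],  # spatial token 1
]
angles = [[0.3, 1.1], [0.7, 2.0]]  # one angle per rotated pair, per spatial token
# Tile each half so that element i and element i + half share the same angle.
sin = [[math.sin(a) for a in ang] * 2 for ang in angles]
cos = [[math.cos(a) for a in ang] * 2 for ang in angles]

out = apply_to_sequence(tokens, sin, cos)
print(out[0])                          # prefix token, unchanged
print(norm(tokens[1]), norm(out[1]))   # rotation preserves the vector norm
```

Because each pair `(x[i], x[i + half])` is rotated by a single angle, the operation changes direction but not length, which is why the prefix tokens can simply be concatenated back without rescaling.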
spectre/models/layers/layernorm.py
ADDED
|
@@ -0,0 +1,23 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from timm.layers.fast_norm import is_fast_norm, fast_layer_norm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LayerNorm3d(nn.LayerNorm):
|
| 8 |
+
""" LayerNorm for channels of '3D' spatial NCHWD tensors """
|
| 9 |
+
_fast_norm: torch.jit.Final[bool]
|
| 10 |
+
|
| 11 |
+
def __init__(self, num_channels, eps=1e-6, affine=True):
|
| 12 |
+
super().__init__(num_channels, eps=eps, elementwise_affine=affine)
|
| 13 |
+
self._fast_norm = is_fast_norm()  # is_fast_norm() is imported from timm.layers.fast_norm
|
| 14 |
+
|
| 15 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 16 |
+
x = x.permute(0, 2, 3, 4, 1)  # NCHWD -> NHWDC (channels last)
|
| 17 |
+
if self._fast_norm:
|
| 18 |
+
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 19 |
+
else:
|
| 20 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 21 |
+
x = x.permute(0, 4, 1, 2, 3) # Permute back to NCHWD format
|
| 22 |
+
return x
|
| 23 |
+
|
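`nn.LayerNorm` normalizes over the trailing dimensions given by `normalized_shape`, so `LayerNorm3d` has to move the channel axis to the end before normalizing and move it back afterwards. The shape bookkeeping of the two permutes can be checked with a small sketch; `permute_shape` is a hypothetical helper that only mimics how `Tensor.permute` reorders dimensions, and the sizes are illustrative.

```python
def permute_shape(shape, order):
    # Shape that tensor.permute(*order) would produce for a tensor of `shape`.
    return tuple(shape[i] for i in order)


nchwd = (2, 768, 8, 8, 8)  # (N, C, H, W, D), illustrative sizes

# First permute in LayerNorm3d.forward: channels move to the last axis.
channels_last = permute_shape(nchwd, (0, 2, 3, 4, 1))

# Second permute: channels move back to axis 1.
restored = permute_shape(channels_last, (0, 4, 1, 2, 3))

print(channels_last)  # (N, H, W, D, C)
print(restored)       # (N, C, H, W, D)
```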
spectre/models/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,135 @@
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple, Union, Callable
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from spectre.utils import (
|
| 9 |
+
to_3tuple,
|
| 10 |
+
resample_patch_embed,
|
| 11 |
+
Format,
|
| 12 |
+
nchwd_to,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PatchEmbed(nn.Module):
|
| 17 |
+
""" 3D Image to Patch Embedding
|
| 18 |
+
"""
|
| 19 |
+
output_fmt: Format
|
| 20 |
+
dynamic_img_pad: torch.jit.Final[bool]
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
img_size: Optional[Union[int, Tuple[int, int, int]]] = (128, 128, 64),
|
| 25 |
+
patch_size: Union[int, Tuple[int, int, int]] = (16, 16, 8),
|
| 26 |
+
in_chans: int = 1,
|
| 27 |
+
embed_dim: int = 768,
|
| 28 |
+
norm_layer: Optional[Callable] = None,
|
| 29 |
+
flatten: bool = True,
|
| 30 |
+
output_fmt: Optional[str] = None,
|
| 31 |
+
bias: bool = True,
|
| 32 |
+
strict_img_size: bool = True,
|
| 33 |
+
dynamic_img_pad: bool = False,
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.patch_size = to_3tuple(patch_size)
|
| 37 |
+
self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
|
| 38 |
+
|
| 39 |
+
if output_fmt is not None:
|
| 40 |
+
self.flatten = False
|
| 41 |
+
self.output_fmt = Format(output_fmt)
|
| 42 |
+
else:
|
| 43 |
+
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
| 44 |
+
self.flatten = flatten
|
| 45 |
+
self.output_fmt = Format.NCHWD
|
| 46 |
+
self.strict_img_size = strict_img_size
|
| 47 |
+
self.dynamic_img_pad = dynamic_img_pad
|
| 48 |
+
|
| 49 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
| 50 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 51 |
+
|
| 52 |
+
def _init_img_size(self, img_size: Union[int, Tuple[int, int, int]]):
|
| 53 |
+
assert self.patch_size
|
| 54 |
+
if img_size is None:
|
| 55 |
+
return None, None, None
|
| 56 |
+
img_size = to_3tuple(img_size)
|
| 57 |
+
grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
|
| 58 |
+
num_patches = grid_size[0] * grid_size[1] * grid_size[2]
|
| 59 |
+
return img_size, grid_size, num_patches
|
| 60 |
+
|
| 61 |
+
def set_input_size(
|
| 62 |
+
self,
|
| 63 |
+
img_size: Optional[Union[int, Tuple[int, int, int]]] = None,
|
| 64 |
+
patch_size: Optional[Union[int, Tuple[int, int, int]]] = None,
|
| 65 |
+
):
|
| 66 |
+
new_patch_size = None
|
| 67 |
+
if patch_size is not None:
|
| 68 |
+
new_patch_size = to_3tuple(patch_size)
|
| 69 |
+
if new_patch_size is not None and new_patch_size != self.patch_size:
|
| 70 |
+
with torch.no_grad():
|
| 71 |
+
new_proj = nn.Conv3d(
|
| 72 |
+
self.proj.in_channels,
|
| 73 |
+
self.proj.out_channels,
|
| 74 |
+
kernel_size=new_patch_size,
|
| 75 |
+
stride=new_patch_size,
|
| 76 |
+
bias=self.proj.bias is not None,
|
| 77 |
+
)
|
| 78 |
+
new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
|
| 79 |
+
if self.proj.bias is not None:
|
| 80 |
+
new_proj.bias.copy_(self.proj.bias)
|
| 81 |
+
self.proj = new_proj
|
| 82 |
+
self.patch_size = new_patch_size
|
| 83 |
+
img_size = img_size or self.img_size
|
| 84 |
+
if img_size != self.img_size or new_patch_size is not None:
|
| 85 |
+
self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
|
| 86 |
+
|
| 87 |
+
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int, int], int]:
|
| 88 |
+
if as_scalar:
|
| 89 |
+
return max(self.patch_size)
|
| 90 |
+
else:
|
| 91 |
+
return self.patch_size
|
| 92 |
+
|
| 93 |
+
def dynamic_feat_size(self, img_size: Tuple[int, int, int]) -> Tuple[int, int, int]:
|
| 94 |
+
""" Get grid (feature) size for given image size taking account of dynamic padding.
|
| 95 |
+
NOTE: must be torchscript compatible so using fixed tuple indexing
|
| 96 |
+
"""
|
| 97 |
+
if self.dynamic_img_pad:
|
| 98 |
+
return (
|
| 99 |
+
math.ceil(img_size[0] / self.patch_size[0]),
|
| 100 |
+
math.ceil(img_size[1] / self.patch_size[1]),
|
| 101 |
+
math.ceil(img_size[2] / self.patch_size[2]),
|
| 102 |
+
)
|
| 103 |
+
else:
|
| 104 |
+
return (
|
| 105 |
+
img_size[0] // self.patch_size[0],
|
| 106 |
+
img_size[1] // self.patch_size[1],
|
| 107 |
+
img_size[2] // self.patch_size[2],
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def forward(self, x):
|
| 111 |
+
_, _, H, W, D = x.shape
|
| 112 |
+
if self.img_size is not None:
|
| 113 |
+
if self.strict_img_size:
|
| 114 |
+
assert H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]})."
|
| 115 |
+
assert W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]})."
|
| 116 |
+
assert D == self.img_size[2], f"Input depth ({D}) doesn't match model ({self.img_size[2]})."
|
| 117 |
+
elif not self.dynamic_img_pad:
|
| 118 |
+
assert H % self.patch_size[0] == 0, \
|
| 119 |
+
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
| 120 |
+
assert W % self.patch_size[1] == 0, \
|
| 121 |
+
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
| 122 |
+
assert D % self.patch_size[2] == 0, \
|
| 123 |
+
f"Input depth ({D}) should be divisible by patch size ({self.patch_size[2]})."
|
| 124 |
+
if self.dynamic_img_pad:
|
| 125 |
+
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
| 126 |
+
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
| 127 |
+
pad_d = (self.patch_size[2] - D % self.patch_size[2]) % self.patch_size[2]
|
| 128 |
+
x = F.pad(x, (0, pad_d, 0, pad_w, 0, pad_h))  # F.pad pairs start at the last dim: D, then W, then H
|
| 129 |
+
x = self.proj(x)
|
| 130 |
+
if self.flatten:
|
| 131 |
+
x = x.flatten(2).transpose(1, 2) # NCHWD -> NLC
|
| 132 |
+
elif self.output_fmt != Format.NCHWD:
|
| 133 |
+
x = nchwd_to(x, self.output_fmt)
|
| 134 |
+
x = self.norm(x)
|
| 135 |
+
return x
|
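The padding and grid-size arithmetic in `PatchEmbed` can be worked through in plain Python. This is a minimal sketch, assuming an input of size `(130, 128, 60)` and the default patch size `(16, 16, 8)`; the helper names are hypothetical. It also spells out the argument order expected by `torch.nn.functional.pad`, which consumes `(before, after)` pairs starting from the last dimension.

```python
import math


def pad_amounts(img_size, patch_size):
    # Trailing padding per axis so each size becomes a multiple of the patch size.
    return tuple((p - s % p) % p for s, p in zip(img_size, patch_size))


def grid_size(img_size, patch_size, dynamic_img_pad):
    # Ceil division when dynamic padding is enabled, floor division otherwise,
    # as in PatchEmbed.dynamic_feat_size.
    if dynamic_img_pad:
        return tuple(math.ceil(s / p) for s, p in zip(img_size, patch_size))
    return tuple(s // p for s, p in zip(img_size, patch_size))


def f_pad_argument(pads_hwd):
    # For an (N, C, H, W, D) tensor the last dimension is D, so the depth pair
    # comes first, then width, then height.
    pad_h, pad_w, pad_d = pads_hwd
    return (0, pad_d, 0, pad_w, 0, pad_h)


img, patch = (130, 128, 60), (16, 16, 8)  # illustrative (H, W, D) sizes
pads = pad_amounts(img, patch)

print(pads)                          # padding along H, W, D
print(f_pad_argument(pads))          # tuple in F.pad order
print(grid_size(img, patch, True))   # grid with dynamic padding
print(grid_size(img, patch, False))  # grid when trailing voxels are dropped
```

With padding enabled the padded volume is `(144, 128, 64)`, which matches the ceil-divided grid; without it, the floor-divided grid discards the incomplete trailing patches.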
spectre/models/layers/rotary_pos_embed.py
ADDED
|
@@ -0,0 +1,157 @@
|
| 1 |
+
import math
|
| 2 |
+
from typing import Literal, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 10 |
+
# x: [..., D], split into halves and rotate
|
| 11 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 12 |
+
return torch.cat([-x2, x1], dim=-1)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def rope_apply(
|
| 16 |
+
x: torch.Tensor,
|
| 17 |
+
sin: torch.Tensor,
|
| 18 |
+
cos: torch.Tensor
|
| 19 |
+
) -> torch.Tensor:
|
| 20 |
+
# x, sin, cos: [..., D]
|
| 21 |
+
return (x * cos) + (rope_rotate_half(x) * sin)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class RotaryPositionEmbedding(nn.Module):
|
| 25 |
+
"""
|
| 26 |
+
3D Rotary Positional Embedding (RoPE) with no mixing across axes (axial),
|
| 27 |
+
and no learnable weights. Allows for shifting and scaling of the positional encodings
|
| 28 |
+
for improving performance on varying resolutions.
|
| 29 |
+
Mirrors DINOv3 style but for (H, W, D).
|
| 30 |
+
|
| 31 |
+
Requirements:
|
| 32 |
+
- head_dim % 6 == 0 (because 3 axes -> periods of size head_dim//6, then we tile to fill head_dim)
|
| 33 |
+
|
| 34 |
+
Two parametrizations:
|
| 35 |
+
* base
|
| 36 |
+
* min_period + max_period
|
| 37 |
+
"""
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
embed_dim: int,
|
| 41 |
+
*,
|
| 42 |
+
num_heads: int,
|
| 43 |
+
base: float | None = 1000.0, # works for common 8^3=512 to 16^3=4096 tokens
|
| 44 |
+
min_period: float | None = None,
|
| 45 |
+
max_period: float | None = None,
|
| 46 |
+
normalize_coords: Literal["min", "max", "separate"] = "separate",
|
| 47 |
+
shift_coords: float | None = None,
|
| 48 |
+
jitter_coords: float | None = None,
|
| 49 |
+
rescale_coords: float | None = None,
|
| 50 |
+
dtype: torch.dtype | None = None,
|
| 51 |
+
device: torch.device | None = None,
|
| 52 |
+
):
|
| 53 |
+
super().__init__()
|
| 54 |
+
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
| 55 |
+
head_dim = embed_dim // num_heads
|
| 56 |
+
assert head_dim % 6 == 0, "For 3D RoPE, (embed_dim // num_heads) must be divisible by 6"
|
| 57 |
+
|
| 58 |
+
both_periods = (min_period is not None) and (max_period is not None)
|
| 59 |
+
if (base is None and not both_periods) or (base is not None and both_periods):
|
| 60 |
+
raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
|
| 61 |
+
|
| 62 |
+
self.base = base
|
| 63 |
+
self.min_period = min_period
|
| 64 |
+
self.max_period = max_period
|
| 65 |
+
self.head_dim = head_dim
|
| 66 |
+
self.normalize_coords = normalize_coords
|
| 67 |
+
self.shift_coords = shift_coords
|
| 68 |
+
self.jitter_coords = jitter_coords
|
| 69 |
+
self.rescale_coords = rescale_coords
|
| 70 |
+
|
| 71 |
+
# Keep dtype persistent so teacher can be initialized from student state_dict()
|
| 72 |
+
self.dtype = dtype
|
| 73 |
+
self.register_buffer(
|
| 74 |
+
"periods",
|
| 75 |
+
torch.empty(self.head_dim // 6, device=device, dtype=dtype),
|
| 76 |
+
persistent=True,
|
| 77 |
+
)
|
| 78 |
+
self._init_weights()
|
| 79 |
+
|
| 80 |
+
@torch.no_grad()
|
| 81 |
+
def _init_weights(self):
|
| 82 |
+
device = self.periods.device
|
| 83 |
+
dtype = self.dtype
|
| 84 |
+
if self.base is not None:
|
| 85 |
+
# exponents 2 * i / (head_dim // 3) for i in 0..(head_dim // 6 - 1), which lie in [0, 1)
|
| 86 |
+
# for 3D, each of the three axes gets head_dim // 6 periods
|
| 87 |
+
periods = self.base ** (
|
| 88 |
+
2 * torch.arange(self.head_dim // 6, device=device, dtype=dtype) / (self.head_dim // 3)
|
| 89 |
+
)
|
| 90 |
+
else:
|
| 91 |
+
# geometric spacing between min_period and max_period
|
| 92 |
+
base = self.max_period / self.min_period
|
| 93 |
+
exponents = torch.linspace(0, 1, self.head_dim // 6, device=device, dtype=dtype)
|
| 94 |
+
periods = base ** exponents
|
| 95 |
+
periods = periods / base
|
| 96 |
+
periods = periods * self.max_period
|
| 97 |
+
self.periods.data = periods
|
| 98 |
+
|
| 99 |
+
def forward(self, *, H: int, W: int, D: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 100 |
+
"""
|
| 101 |
+
Returns:
|
| 102 |
+
sin, cos: [H * W * D, head_dim] (per-head)
|
| 103 |
+
"""
|
| 104 |
+
device = self.periods.device
|
| 105 |
+
dtype = self.dtype
|
| 106 |
+
dd = dict(device=device, dtype=dtype)
|
| 107 |
+
|
| 108 |
+
# Prepare voxel-center coords in [0, 1] (can exceed 1 with "min" normalization), then map to [-1, +1]
|
| 109 |
+
if self.normalize_coords == "max":
|
| 110 |
+
max_dim = max(H, W, D)
|
| 111 |
+
coords_h = torch.arange(0.5, H, **dd) / max_dim
|
| 112 |
+
coords_w = torch.arange(0.5, W, **dd) / max_dim
|
| 113 |
+
coords_d = torch.arange(0.5, D, **dd) / max_dim
|
| 114 |
+
elif self.normalize_coords == "min":
|
| 115 |
+
min_dim = min(H, W, D)
|
| 116 |
+
coords_h = torch.arange(0.5, H, **dd) / min_dim
|
| 117 |
+
coords_w = torch.arange(0.5, W, **dd) / min_dim
|
| 118 |
+
coords_d = torch.arange(0.5, D, **dd) / min_dim
|
| 119 |
+
elif self.normalize_coords == "separate":
|
| 120 |
+
coords_h = torch.arange(0.5, H, **dd) / H
|
| 121 |
+
coords_w = torch.arange(0.5, W, **dd) / W
|
| 122 |
+
coords_d = torch.arange(0.5, D, **dd) / D
|
| 123 |
+
else:
|
| 124 |
+
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
|
| 125 |
+
|
| 126 |
+
coords = torch.stack(
|
| 127 |
+
torch.meshgrid(coords_h, coords_w, coords_d, indexing="ij"),
|
| 128 |
+
dim=-1
|
| 129 |
+
) # [H, W, D, 3]
|
| 130 |
+
coords = coords.flatten(0, 2) # [HWD, 3]
|
| 131 |
+
coords = 2.0 * coords - 1.0 # [-1, +1]
|
| 132 |
+
|
| 133 |
+
# Optional train-time augmentations on coords (DINOv3)
|
| 134 |
+
if self.training and self.shift_coords is not None:
|
| 135 |
+
shift_hwd = torch.empty(3, **dd).uniform_(-self.shift_coords, self.shift_coords)
|
| 136 |
+
coords = coords + shift_hwd[None, :]
|
| 137 |
+
|
| 138 |
+
if self.training and self.jitter_coords is not None:
|
| 139 |
+
jit_max = np.log(self.jitter_coords); jit_min = -jit_max
|
| 140 |
+
jitter = torch.empty(3, **dd).uniform_(jit_min, jit_max).exp()
|
| 141 |
+
coords = coords * jitter[None, :]
|
| 142 |
+
|
| 143 |
+
if self.training and self.rescale_coords is not None:
|
| 144 |
+
r_max = np.log(self.rescale_coords); r_min = -r_max
|
| 145 |
+
rescale = torch.empty(1, **dd).uniform_(r_min, r_max).exp()
|
| 146 |
+
coords = coords * rescale
|
| 147 |
+
|
| 148 |
+
# --- Build angles per axis, then concatenate across axes ---
|
| 149 |
+
# coords: [N, 3] ; periods: [head_dim // 6]
|
| 150 |
+
# angles: [N, 3, head_dim // 6] -> flatten(1, 2) -> [N, head_dim // 2] -> tile(2) -> [N, head_dim]
|
| 151 |
+
angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] # [N, 3, head_dim // 6]
|
| 152 |
+
angles = angles.flatten(1, 2) # [N, head_dim // 2]
|
| 153 |
+
angles = angles.tile(2) # [N, head_dim]
|
| 154 |
+
|
| 155 |
+
cos = torch.cos(angles)
|
| 156 |
+
sin = torch.sin(angles)
|
| 157 |
+
return sin, cos
|
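The module above only returns `sin` and `cos`; how they are consumed is not shown in this file. The sketch below is a minimal, standalone illustration under stated assumptions: it re-derives the `base` branch with `"separate"` normalization and applies the result with the common rotate-half convention, which matches the `tile(2)` layout of the angles. The names `rope_3d_sin_cos`, `rotate_half`, and `apply_rope` are hypothetical, `base=100.0` is an illustrative value, and the repository's attention layer may apply the embedding differently.

```python
import math
import torch


def rope_3d_sin_cos(H: int, W: int, D: int, head_dim: int, base: float = 100.0):
    """Standalone re-derivation of the `base` branch with "separate" normalization."""
    assert head_dim % 6 == 0, "head_dim must be divisible by 6 for 3D RoPE"
    n_freq = head_dim // 6  # frequencies per spatial axis
    periods = base ** (2 * torch.arange(n_freq, dtype=torch.float32) / (head_dim // 3))
    # Voxel-center coordinates in [0, 1] per axis, then mapped to [-1, +1].
    axes = [torch.arange(0.5, n, dtype=torch.float32) / n for n in (H, W, D)]
    coords = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).flatten(0, 2)
    coords = 2.0 * coords - 1.0  # [N, 3]
    angles = 2 * math.pi * coords[:, :, None] / periods[None, None, :]  # [N, 3, n_freq]
    angles = angles.flatten(1, 2).tile(2)  # [N, head_dim]; second half repeats the first
    return torch.sin(angles), torch.cos(angles)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the channel dim into halves (x1, x2) and return (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    # x: [B, heads, N, head_dim]; sin/cos: [N, head_dim], broadcast over batch and heads.
    return x * cos + rotate_half(x) * sin


sin, cos = rope_3d_sin_cos(H=2, W=3, D=4, head_dim=12)
q = torch.randn(1, 2, 2 * 3 * 4, 12)
q_rot = apply_rope(q, sin, cos)
```

Because channel `i` and channel `i + head_dim // 2` share one angle, each such pair undergoes a plain 2D rotation, so the per-token norm of `q` is unchanged.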
spectre/models/resnet.py
ADDED
|
@@ -0,0 +1,726 @@
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
from urllib.parse import urlparse
|
| 4 |
+
from typing import Type, Any, Tuple, List, Optional, Union, Dict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from spectre.utils import to_ntuple
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
|
| 14 |
+
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
| 15 |
+
return padding
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BasicBlock(nn.Module):
|
| 19 |
+
expansion = 1
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
inplanes: int,
|
| 24 |
+
planes: int,
|
| 25 |
+
stride: int = 1,
|
| 26 |
+
downsample: Optional[nn.Module] = None,
|
| 27 |
+
cardinality: int = 1,
|
| 28 |
+
base_width: int = 64,
|
| 29 |
+
reduce_first: int = 1,
|
| 30 |
+
dilation: int = 1,
|
| 31 |
+
first_dilation: Optional[int] = None,
|
| 32 |
+
act_layer: Type[nn.Module] = nn.ReLU,
|
| 33 |
+
norm_layer: Type[nn.Module] = nn.BatchNorm3d,
|
| 34 |
+
):
|
| 35 |
+
"""
|
| 36 |
+
Args:
|
| 37 |
+
inplanes: Input channel dimensionality.
|
| 38 |
+
planes: Used to determine output channel dimensionalities.
|
| 39 |
+
stride: Stride used in convolution layers.
|
| 40 |
+
downsample: Optional downsample layer for residual path.
|
| 41 |
+
cardinality: Number of convolution groups.
|
| 42 |
+
base_width: Base width used to determine output channel dimensionality.
|
| 43 |
+
reduce_first: Reduction factor for first convolution output width of residual blocks.
|
| 44 |
+
dilation: Dilation rate for convolution layers.
|
| 45 |
+
first_dilation: Dilation rate for first convolution layer.
|
| 46 |
+
act_layer: Activation layer.
|
| 47 |
+
norm_layer: Normalization layer.
|
| 48 |
+
"""
|
| 49 |
+
super(BasicBlock, self).__init__()
|
| 50 |
+
|
| 51 |
+
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
| 52 |
+
assert base_width == 64, 'BasicBlock does not support changing base width'
|
| 53 |
+
first_planes = planes // reduce_first
|
| 54 |
+
outplanes = planes * self.expansion
|
| 55 |
+
first_dilation = first_dilation or dilation
|
| 56 |
+
|
| 57 |
+
self.conv1 = nn.Conv3d(
|
| 58 |
+
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
| 59 |
+
dilation=first_dilation, bias=False)
|
| 60 |
+
self.bn1 = norm_layer(first_planes)
|
| 61 |
+
self.act1 = act_layer()
|
| 62 |
+
|
| 63 |
+
self.conv2 = nn.Conv3d(
|
| 64 |
+
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
|
| 65 |
+
self.bn2 = norm_layer(outplanes)
|
| 66 |
+
|
| 67 |
+
self.act2 = act_layer()
|
| 68 |
+
self.downsample = downsample
|
| 69 |
+
self.stride = stride
|
| 70 |
+
self.dilation = dilation
|
| 71 |
+
|
| 72 |
+
def zero_init_last(self):
|
| 73 |
+
if getattr(self.bn2, 'weight', None) is not None:
|
| 74 |
+
nn.init.zeros_(self.bn2.weight)
|
| 75 |
+
|
| 76 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 77 |
+
shortcut = x
|
| 78 |
+
|
| 79 |
+
x = self.conv1(x)
|
| 80 |
+
x = self.bn1(x)
|
| 81 |
+
x = self.act1(x)
|
| 82 |
+
|
| 83 |
+
x = self.conv2(x)
|
| 84 |
+
x = self.bn2(x)
|
| 85 |
+
|
| 86 |
+
if self.downsample is not None:
|
| 87 |
+
shortcut = self.downsample(shortcut)
|
| 88 |
+
x = x + shortcut
|
| 89 |
+
x = self.act2(x)
|
| 90 |
+
|
| 91 |
+
return x
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Bottleneck(nn.Module):
|
| 95 |
+
expansion = 4
|
| 96 |
+
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
inplanes: int,
|
| 100 |
+
planes: int,
|
| 101 |
+
stride: int = 1,
|
| 102 |
+
downsample: Optional[nn.Module] = None,
|
| 103 |
+
cardinality: int = 1,
|
| 104 |
+
base_width: int = 64,
|
| 105 |
+
reduce_first: int = 1,
|
| 106 |
+
dilation: int = 1,
|
| 107 |
+
first_dilation: Optional[int] = None,
|
| 108 |
+
act_layer: Type[nn.Module] = nn.ReLU,
|
| 109 |
+
norm_layer: Type[nn.Module] = nn.BatchNorm3d,
|
| 110 |
+
):
|
| 111 |
+
"""
|
| 112 |
+
Args:
|
| 113 |
+
inplanes: Input channel dimensionality.
|
| 114 |
+
planes: Used to determine output channel dimensionalities.
|
| 115 |
+
stride: Stride used in convolution layers.
|
| 116 |
+
downsample: Optional downsample layer for residual path.
|
| 117 |
+
cardinality: Number of convolution groups.
|
| 118 |
+
base_width: Base width used to determine output channel dimensionality.
|
| 119 |
+
reduce_first: Reduction factor for first convolution output width of residual blocks.
|
| 120 |
+
dilation: Dilation rate for convolution layers.
|
| 121 |
+
first_dilation: Dilation rate for first convolution layer.
|
| 122 |
+
act_layer: Activation layer.
|
| 123 |
+
norm_layer: Normalization layer.
|
| 124 |
+
"""
|
| 125 |
+
super(Bottleneck, self).__init__()
|
| 126 |
+
|
| 127 |
+
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
| 128 |
+
first_planes = width // reduce_first
|
| 129 |
+
outplanes = planes * self.expansion
|
| 130 |
+
first_dilation = first_dilation or dilation
|
| 131 |
+
|
| 132 |
+
self.conv1 = nn.Conv3d(inplanes, first_planes, kernel_size=1, bias=False)
|
| 133 |
+
self.bn1 = norm_layer(first_planes)
|
| 134 |
+
self.act1 = act_layer()
|
| 135 |
+
|
| 136 |
+
self.conv2 = nn.Conv3d(
|
| 137 |
+
first_planes, width, kernel_size=3, stride=stride,
|
| 138 |
+
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
| 139 |
+
self.bn2 = norm_layer(width)
|
| 140 |
+
self.act2 = act_layer()
|
| 141 |
+
|
| 142 |
+
self.conv3 = nn.Conv3d(width, outplanes, kernel_size=1, bias=False)
|
| 143 |
+
self.bn3 = norm_layer(outplanes)
|
| 144 |
+
|
| 145 |
+
self.act3 = act_layer()
|
| 146 |
+
self.downsample = downsample
|
| 147 |
+
self.stride = stride
|
| 148 |
+
self.dilation = dilation
|
| 149 |
+
|
| 150 |
+
def zero_init_last(self):
|
| 151 |
+
if getattr(self.bn3, 'weight', None) is not None:
|
| 152 |
+
nn.init.zeros_(self.bn3.weight)
|
| 153 |
+
|
| 154 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 155 |
+
shortcut = x
|
| 156 |
+
|
| 157 |
+
x = self.conv1(x)
|
| 158 |
+
x = self.bn1(x)
|
| 159 |
+
x = self.act1(x)
|
| 160 |
+
|
| 161 |
+
x = self.conv2(x)
|
| 162 |
+
x = self.bn2(x)
|
| 163 |
+
x = self.act2(x)
|
| 164 |
+
|
| 165 |
+
x = self.conv3(x)
|
| 166 |
+
x = self.bn3(x)
|
| 167 |
+
|
| 168 |
+
if self.downsample is not None:
|
| 169 |
+
shortcut = self.downsample(shortcut)
|
| 170 |
+
x = x + shortcut
|
| 171 |
+
x = self.act3(x)
|
| 172 |
+
|
| 173 |
+
return x
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def downsample_conv(
|
| 177 |
+
in_channels: int,
|
| 178 |
+
out_channels: int,
|
| 179 |
+
kernel_size: int,
|
| 180 |
+
stride: int = 1,
|
| 181 |
+
dilation: int = 1,
|
| 182 |
+
first_dilation: Optional[int] = None,
|
| 183 |
+
norm_layer: Optional[Type[nn.Module]] = None,
|
| 184 |
+
) -> nn.Module:
|
| 185 |
+
norm_layer = norm_layer or nn.BatchNorm3d
|
| 186 |
+
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
|
| 187 |
+
first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
|
| 188 |
+
p = get_padding(kernel_size, stride, first_dilation)
|
| 189 |
+
|
| 190 |
+
return nn.Sequential(*[
|
| 191 |
+
nn.Conv3d(
|
| 192 |
+
in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False),
|
| 193 |
+
norm_layer(out_channels)
|
| 194 |
+
])
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def downsample_avg(
|
| 198 |
+
in_channels: int,
|
| 199 |
+
out_channels: int,
|
| 200 |
+
kernel_size: int,
|
| 201 |
+
stride: int = 1,
|
| 202 |
+
dilation: int = 1,
|
| 203 |
+
first_dilation: Optional[int] = None,
|
| 204 |
+
norm_layer: Optional[Type[nn.Module]] = None,
|
| 205 |
+
) -> nn.Module:
|
| 206 |
+
norm_layer = norm_layer or nn.BatchNorm3d
|
| 207 |
+
avg_stride = stride if dilation == 1 else 1
|
| 208 |
+
if stride == 1 and dilation == 1:
|
| 209 |
+
pool = nn.Identity()
|
| 210 |
+
else:
|
| 211 |
+
pool = nn.AvgPool3d(2, avg_stride, ceil_mode=True, count_include_pad=False)
|
| 212 |
+
|
| 213 |
+
return nn.Sequential(*[
|
| 214 |
+
pool,
|
| 215 |
+
nn.Conv3d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
|
| 216 |
+
norm_layer(out_channels)
|
| 217 |
+
])
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def make_blocks(
|
| 221 |
+
block_fns: Tuple[Type[Union[BasicBlock, Bottleneck]], ...],
|
| 222 |
+
channels: Tuple[int, ...],
|
| 223 |
+
block_repeats: Tuple[int, ...],
|
| 224 |
+
inplanes: int,
|
| 225 |
+
reduce_first: int = 1,
|
| 226 |
+
output_stride: int = 32,
|
| 227 |
+
down_kernel_size: int = 1,
|
| 228 |
+
avg_down: bool = False,
|
| 229 |
+
**kwargs,
|
| 230 |
+
) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
|
| 231 |
+
stages = []
|
| 232 |
+
feature_info = []
|
| 233 |
+
net_num_blocks = sum(block_repeats)
|
| 234 |
+
net_block_idx = 0
|
| 235 |
+
net_stride = 4
|
| 236 |
+
dilation = prev_dilation = 1
|
| 237 |
+
for stage_idx, (block_fn, planes, num_blocks) in enumerate(zip(block_fns, channels, block_repeats)):
|
| 238 |
+
stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
|
| 239 |
+
stride = 1 if stage_idx == 0 else 2
|
| 240 |
+
if net_stride >= output_stride:
|
| 241 |
+
dilation *= stride
|
| 242 |
+
stride = 1
|
| 243 |
+
else:
|
| 244 |
+
net_stride *= stride
|
| 245 |
+
|
| 246 |
+
downsample = None
|
| 247 |
+
if stride != 1 or inplanes != planes * block_fn.expansion:
|
| 248 |
+
down_kwargs = dict(
|
| 249 |
+
in_channels=inplanes,
|
| 250 |
+
out_channels=planes * block_fn.expansion,
|
| 251 |
+
kernel_size=down_kernel_size,
|
| 252 |
+
stride=stride,
|
| 253 |
+
dilation=dilation,
|
| 254 |
+
first_dilation=prev_dilation,
|
| 255 |
+
norm_layer=kwargs.get('norm_layer'),
|
| 256 |
+
)
|
| 257 |
+
downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
|
| 258 |
+
|
| 259 |
+
block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, **kwargs)
|
| 260 |
+
blocks = []
|
| 261 |
+
for block_idx in range(num_blocks):
|
| 262 |
+
downsample = downsample if block_idx == 0 else None
|
| 263 |
+
stride = stride if block_idx == 0 else 1
|
| 264 |
+
blocks.append(block_fn(
|
| 265 |
+
inplanes,
|
| 266 |
+
planes,
|
| 267 |
+
stride,
|
| 268 |
+
downsample,
|
| 269 |
+
first_dilation=prev_dilation,
|
| 270 |
+
**block_kwargs,
|
| 271 |
+
))
|
| 272 |
+
prev_dilation = dilation
|
| 273 |
+
inplanes = planes * block_fn.expansion
|
| 274 |
+
net_block_idx += 1
|
| 275 |
+
|
| 276 |
+
stages.append((stage_name, nn.Sequential(*blocks)))
|
| 277 |
+
feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
|
| 278 |
+
|
| 279 |
+
return stages, feature_info
|
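The stride/dilation bookkeeping in `make_blocks` can be hard to follow inside the loop. As a rough sketch (the helper name `stage_schedule` is hypothetical and not part of the repository), the same scheduling logic can be isolated in pure Python to show which stages downsample and which switch to dilation for a given `output_stride`. It does not model `first_dilation`, which the first block of each stage takes from the previous stage.

```python
from typing import List, Tuple


def stage_schedule(num_stages: int = 4, output_stride: int = 32) -> List[Tuple[int, int, int]]:
    """Return (stride, dilation, cumulative_stride) per stage, mirroring make_blocks."""
    net_stride = 4  # the stem (stride-2 conv + stride-2 pool) already reduces by 4
    dilation = 1
    schedule = []
    for stage_idx in range(num_stages):
        stride = 1 if stage_idx == 0 else 2
        if net_stride >= output_stride:
            # Target stride reached: trade further downsampling for dilation.
            dilation *= stride
            stride = 1
        else:
            net_stride *= stride
        schedule.append((stride, dilation, net_stride))
    return schedule


for target in (32, 16, 8):
    print(target, stage_schedule(output_stride=target))
```

With `output_stride=8`, for example, only the second stage downsamples; the third and fourth stages keep stride 1 and use dilations 2 and 4 instead.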
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def feature_take_indices(
|
| 283 |
+
num_features: int,
|
| 284 |
+
indices: Optional[Union[int, List[int]]] = None,
|
| 285 |
+
as_set: bool = False,
|
| 286 |
+
) -> Tuple[List[int], int]:
|
| 287 |
+
""" Determine the absolute feature indices to 'take' from.
|
| 288 |
+
|
| 289 |
+
Note: This function can be called in forward() so must be torchscript compatible,
|
| 290 |
+
which requires some incomplete typing and workaround hacks.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
num_features: total number of features to select from
|
| 294 |
+
indices: indices to select,
|
| 295 |
+
None -> select all
|
| 296 |
+
int -> select last n
|
| 297 |
+
list/tuple of int -> return specified (-ve indices specify from end)
|
| 298 |
+
as_set: return as a set
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
List (or set) of absolute (from beginning) indices, Maximum index
|
| 302 |
+
"""
|
| 303 |
+
if indices is None:
|
| 304 |
+
indices = num_features # all features if None
|
| 305 |
+
|
| 306 |
+
if isinstance(indices, int):
|
| 307 |
+
# convert int -> last n indices
|
| 308 |
+
assert 0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})'
|
| 309 |
+
take_indices = [num_features - indices + i for i in range(indices)]
|
| 310 |
+
else:
|
| 311 |
+
take_indices: List[int] = []
|
| 312 |
+
for i in indices:
|
| 313 |
+
idx = num_features + i if i < 0 else i
|
| 314 |
+
assert 0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})'
|
| 315 |
+
take_indices.append(idx)
|
| 316 |
+
|
| 317 |
+
if not torch.jit.is_scripting() and as_set:
|
| 318 |
+
return set(take_indices), max(take_indices)
|
| 319 |
+
|
| 320 |
+
return take_indices, max(take_indices)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class ResNet(nn.Module):
|
| 324 |
+
"""ResNet / ResNeXt
|
| 325 |
+
|
| 326 |
+
This class implements all variants of ResNet, ResNeXt that
|
| 327 |
+
* have > 1 stride in the 3x3 conv layer of bottleneck
|
| 328 |
+
* have conv-bn-act ordering
|
| 329 |
+
|
| 330 |
+
This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
|
| 331 |
+
variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
|
| 332 |
+
'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
|
| 333 |
+
|
| 334 |
+
ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
|
| 335 |
+
* normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
|
| 336 |
+
* c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
|
| 337 |
+
* d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
|
| 338 |
+
* e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
|
| 339 |
+
* s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
|
| 340 |
+
* t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
|
| 341 |
+
* tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample
|
| 342 |
+
|
| 343 |
+
ResNeXt
|
| 344 |
+
* normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
|
| 345 |
+
* same c, d, e, s variants as ResNet can be enabled
|
| 346 |
+
"""
|
| 347 |
+
|
| 348 |
+
def __init__(
|
| 349 |
+
self,
|
| 350 |
+
block: Type[Union[BasicBlock, Bottleneck]],
|
| 351 |
+
layers: Tuple[int, ...],
|
| 352 |
+
num_classes: int = 1000,
|
| 353 |
+
in_chans: int = 1,
|
| 354 |
+
output_stride: int = 32,
|
| 355 |
+
global_pool: str = 'avg',
|
| 356 |
+
cardinality: int = 1,
|
| 357 |
+
base_width: int = 64,
|
| 358 |
+
stem_width: int = 64,
|
| 359 |
+
stem_type: str = '',
|
| 360 |
+
replace_stem_pool: bool = False,
|
| 361 |
+
block_reduce_first: int = 1,
|
| 362 |
+
down_kernel_size: int = 1,
|
| 363 |
+
avg_down: bool = False,
|
| 364 |
+
channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
|
| 365 |
+
act_layer: Type[nn.Module] = nn.ReLU,
|
| 366 |
+
norm_layer: Type[nn.Module] = nn.BatchNorm3d,
|
| 367 |
+
drop_rate: float = 0.0,
|
| 368 |
+
zero_init_last: bool = True,
|
| 369 |
+
block_args: Optional[Dict[str, Any]] = None,
|
| 370 |
+
):
|
| 371 |
+
"""
|
| 372 |
+
Args:
|
| 373 |
+
block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
|
| 374 |
+
layers (Tuple[int, ...]): number of residual blocks in each stage
|
| 375 |
+
num_classes (int): number of classification classes (default 1000)
|
| 376 |
+
in_chans (int): number of input channels. (default 1)
|
| 377 |
+
output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
|
| 378 |
+
global_pool (str): Global pooling type. One of 'avg', 'max' (default 'avg')
|
| 379 |
+
cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
|
| 380 |
+
base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
|
| 381 |
+
stem_width (int): number of channels in stem convolutions (default 64)
|
| 382 |
+
stem_type (str): The type of stem (default ''):
|
| 383 |
+
* '', default - a single 7x7 conv with a width of stem_width
|
| 384 |
+
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
|
| 385 |
+
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
|
| 386 |
+
replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
|
| 387 |
+
block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
|
| 388 |
+
1 for all archs except senets, where 2 (default 1)
|
| 389 |
+
down_kernel_size (int): kernel size of residual block downsample path,
|
| 390 |
+
1x1 for most, 3x3 for senets (default: 1)
|
| 391 |
+
avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
|
| 392 |
+
act_layer (Type[nn.Module]): activation layer class
|
| 393 |
+
norm_layer (Type[nn.Module]): normalization layer class
|
| 394 |
+
drop_rate (float): Dropout probability before classifier, for training (default 0.)
|
| 395 |
+
zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
|
| 396 |
+
block_args (dict): Extra kwargs to pass through to block module
|
| 397 |
+
"""
|
| 398 |
+
super(ResNet, self).__init__()
|
| 399 |
+
block_args = block_args or dict()
|
| 400 |
+
assert output_stride in (8, 16, 32)
|
| 401 |
+
self.num_classes = num_classes
|
| 402 |
+
self.drop_rate = drop_rate
|
| 403 |
+
|
| 404 |
+
# Stem
|
| 405 |
+
deep_stem = 'deep' in stem_type
|
| 406 |
+
inplanes = stem_width * 2 if deep_stem else 64
|
| 407 |
+
if deep_stem:
|
| 408 |
+
stem_chs = (stem_width, stem_width)
|
| 409 |
+
if 'tiered' in stem_type:
|
| 410 |
+
stem_chs = (3 * (stem_width // 4), stem_width)
|
| 411 |
+
self.conv1 = nn.Sequential(*[
|
| 412 |
+
nn.Conv3d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
|
| 413 |
+
norm_layer(stem_chs[0]),
|
| 414 |
+
act_layer(),
|
| 415 |
+
nn.Conv3d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
|
| 416 |
+
norm_layer(stem_chs[1]),
|
| 417 |
+
act_layer(),
|
| 418 |
+
nn.Conv3d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
|
| 419 |
+
else:
|
| 420 |
+
self.conv1 = nn.Conv3d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
| 421 |
+
self.bn1 = norm_layer(inplanes)
|
| 422 |
+
self.act1 = act_layer()
|
| 423 |
+
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
|
| 424 |
+
|
| 425 |
+
# Stem pooling. The name 'maxpool' remains for weight compatibility.
|
| 426 |
+
if replace_stem_pool:
|
| 427 |
+
self.maxpool = nn.Sequential(*filter(None, [
|
| 428 |
+
nn.Conv3d(inplanes, inplanes, 3, stride=2, padding=1, bias=False),
|
| 429 |
+
norm_layer(inplanes),
|
| 430 |
+
act_layer(),
|
| 431 |
+
]))
|
| 432 |
+
else:
|
| 433 |
+
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
|
| 434 |
+
|
| 435 |
+
# Feature Blocks
|
| 436 |
+
block_fns = to_ntuple(len(channels))(block)
|
| 437 |
+
stage_modules, stage_feature_info = make_blocks(
|
| 438 |
+
block_fns,
|
| 439 |
+
channels,
|
| 440 |
+
layers,
|
| 441 |
+
inplanes,
|
| 442 |
+
cardinality=cardinality,
|
| 443 |
+
base_width=base_width,
|
| 444 |
+
output_stride=output_stride,
|
| 445 |
+
reduce_first=block_reduce_first,
|
| 446 |
+
avg_down=avg_down,
|
| 447 |
+
down_kernel_size=down_kernel_size,
|
| 448 |
+
act_layer=act_layer,
|
| 449 |
+
norm_layer=norm_layer,
|
| 450 |
+
**block_args,
|
| 451 |
+
)
|
| 452 |
+
for stage in stage_modules:
|
| 453 |
+
self.add_module(*stage) # layer1, layer2, etc
|
| 454 |
+
self.feature_info.extend(stage_feature_info)
|
| 455 |
+
|
| 456 |
+
# Head (Pooling and Classifier)
|
| 457 |
+
self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
|
| 458 |
+
if global_pool == 'avg':
|
| 459 |
+
self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
|
| 460 |
+
elif global_pool == 'max':
|
| 461 |
+
self.global_pool = nn.AdaptiveMaxPool3d((1, 1, 1))
|
| 462 |
+
else:
|
| 463 |
+
raise NotImplementedError('Global pooling type not supported: {}'.format(global_pool))
|
| 464 |
+
|
| 465 |
+
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
| 466 |
+
|
| 467 |
+
self.init_weights(zero_init_last=zero_init_last)
|
| 468 |
+
|
| 469 |
+
@torch.jit.ignore
|
| 470 |
+
def init_weights(self, zero_init_last: bool = True):
|
| 471 |
+
for n, m in self.named_modules():
|
| 472 |
+
if isinstance(m, nn.Conv3d):
|
| 473 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 474 |
+
if zero_init_last:
|
| 475 |
+
for m in self.modules():
|
| 476 |
+
if hasattr(m, 'zero_init_last'):
|
| 477 |
+
m.zero_init_last()
|
| 478 |
+
|
| 479 |
+
@torch.jit.ignore
|
| 480 |
+
def group_matcher(self, coarse: bool = False):
|
| 481 |
+
matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
|
| 482 |
+
return matcher
|
| 483 |
+
|
| 484 |
+
@torch.jit.ignore
|
| 485 |
+
def get_classifier(self, name_only: bool = False):
|
| 486 |
+
return 'fc' if name_only else self.fc
|
| 487 |
+
|
| 488 |
+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
| 489 |
+
self.num_classes = num_classes
|
| 490 |
+
if global_pool == 'avg':
|
| 491 |
+
self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
|
| 492 |
+
elif global_pool == 'max':
|
| 493 |
+
self.global_pool = nn.AdaptiveMaxPool3d((1, 1, 1))
|
| 494 |
+
else:
|
| 495 |
+
raise NotImplementedError('Global pooling type not supported: {}'.format(global_pool))
|
| 496 |
+
|
| 497 |
+
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
| 498 |
+
|
| 499 |
+
def forward_intermediates(
|
| 500 |
+
self,
|
| 501 |
+
x: torch.Tensor,
|
| 502 |
+
indices: Optional[Union[int, List[int]]] = None,
|
| 503 |
+
norm: bool = False,
|
| 504 |
+
stop_early: bool = False,
|
| 505 |
+
output_fmt: str = 'NCHWD',
|
| 506 |
+
intermediates_only: bool = False,
|
| 507 |
+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
| 508 |
+
""" Forward features that returns intermediates.
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
x: Input image tensor
|
| 512 |
+
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
| 513 |
+
norm: Apply norm layer to compatible intermediates (accepted for API compatibility; unused in this implementation)
|
| 514 |
+
stop_early: Stop iterating over blocks when last desired intermediate hit
|
| 515 |
+
output_fmt: Shape of intermediate feature outputs
|
| 516 |
+
intermediates_only: Only return intermediate features
|
| 517 |
+
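Returns: list of intermediate feature maps if intermediates_only, otherwise a tuple (final features, intermediates).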
|
| 518 |
+
|
| 519 |
+
"""
|
| 520 |
+
assert output_fmt in ('NCHWD',), 'Output shape must be NCHWD.'
|
| 521 |
+
intermediates = []
|
| 522 |
+
take_indices, max_index = feature_take_indices(5, indices)
|
| 523 |
+
|
| 524 |
+
# forward pass
|
| 525 |
+
feat_idx = 0
|
| 526 |
+
x = self.conv1(x)
|
| 527 |
+
x = self.bn1(x)
|
| 528 |
+
x = self.act1(x)
|
| 529 |
+
if feat_idx in take_indices:
|
| 530 |
+
intermediates.append(x)
|
| 531 |
+
x = self.maxpool(x)
|
| 532 |
+
|
| 533 |
+
layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
|
| 534 |
+
if stop_early:
|
| 535 |
+
layer_names = layer_names[:max_index]
|
| 536 |
+
for n in layer_names:
|
| 537 |
+
feat_idx += 1
|
| 538 |
+
x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML
|
| 539 |
+
if feat_idx in take_indices:
|
| 540 |
+
intermediates.append(x)
|
| 541 |
+
|
| 542 |
+
if intermediates_only:
|
| 543 |
+
return intermediates
|
| 544 |
+
|
| 545 |
+
return x, intermediates
|
| 546 |
+
|
| 547 |
+
def prune_intermediate_layers(
|
| 548 |
+
self,
|
| 549 |
+
indices: Union[int, List[int]] = 1,
|
| 550 |
+
prune_norm: bool = False,
|
| 551 |
+
prune_head: bool = True,
|
| 552 |
+
):
|
| 553 |
+
""" Prune layers not required for specified intermediates.
|
| 554 |
+
"""
|
| 555 |
+
take_indices, max_index = feature_take_indices(5, indices)
|
| 556 |
+
layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
|
| 557 |
+
layer_names = layer_names[max_index:]
|
| 558 |
+
for n in layer_names:
|
| 559 |
+
setattr(self, n, nn.Identity())
|
| 560 |
+
if prune_head:
|
| 561 |
+
self.reset_classifier(0)  # '' is not an accepted global_pool in reset_classifier; keep the default 'avg'
|
| 562 |
+
return take_indices
|
| 563 |
+
|
| 564 |
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 565 |
+
x = self.conv1(x)
|
| 566 |
+
x = self.bn1(x)
|
| 567 |
+
x = self.act1(x)
|
| 568 |
+
x = self.maxpool(x)
|
| 569 |
+
|
| 570 |
+
x = self.layer1(x)
|
| 571 |
+
x = self.layer2(x)
|
| 572 |
+
x = self.layer3(x)
|
| 573 |
+
x = self.layer4(x)
|
| 574 |
+
return x
|
| 575 |
+
|
| 576 |
+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
| 577 |
+
x = self.global_pool(x)
|
| 578 |
+
x = x.flatten(1)
|
| 579 |
+
if self.drop_rate:
|
| 580 |
+
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
| 581 |
+
return x if pre_logits else self.fc(x)
|
| 582 |
+
|
| 583 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 584 |
+
x = self.forward_features(x)
|
| 585 |
+
x = self.forward_head(x)
|
| 586 |
+
return x
|
| 587 |
+
|
| 588 |
+
@classmethod
|
| 589 |
+
def from_pretrained(
|
| 590 |
+
cls,
|
| 591 |
+
checkpoint_path_or_url: Union[str, os.PathLike],
|
| 592 |
+
verbose: bool = True,
|
| 593 |
+
**kwargs
|
| 594 |
+
) -> 'ResNet':
|
| 595 |
+
"""Load pretrained model weights from a local path or a URL."""
|
| 596 |
+
model = cls(**kwargs)
|
| 597 |
+
|
| 598 |
+
def _is_url(path: str) -> bool:
|
| 599 |
+
try:
|
| 600 |
+
parsed = urlparse(str(path))
|
| 601 |
+
return parsed.scheme in ('http', 'https')
|
| 602 |
+
except Exception:
|
| 603 |
+
return False
|
| 604 |
+
|
| 605 |
+
if _is_url(checkpoint_path_or_url):
|
| 606 |
+
if verbose:
|
| 607 |
+
print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
|
| 608 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 609 |
+
checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
|
| 610 |
+
else:
|
| 611 |
+
local_path = os.fspath(checkpoint_path_or_url)
|
| 612 |
+
if not os.path.exists(local_path):
|
| 613 |
+
raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
|
| 614 |
+
if verbose:
|
| 615 |
+
print(f"Loading checkpoint from local path: {local_path}")
|
| 616 |
+
state_dict = torch.load(local_path, map_location='cpu', weights_only=False)
|
| 617 |
+
|
| 618 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
| 619 |
+
if verbose:
|
| 620 |
+
print(f"Loaded pretrained weights with msg: {msg}")
|
| 621 |
+
return model
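The URL check inside `from_pretrained` can be tried on its own. The following is a minimal standalone sketch using only the standard library; `is_url` is a hypothetical module-level copy of the nested `_is_url` helper, and the example paths are made up for illustration.

```python
from urllib.parse import urlparse


def is_url(path) -> bool:
    """Return True only for http(s) URLs, mirroring the nested `_is_url` helper."""
    try:
        return urlparse(str(path)).scheme in ("http", "https")
    except Exception:
        return False


# Hypothetical inputs, for illustration only.
print(is_url("https://example.com/weights.pt"))  # remote checkpoint
print(is_url("/data/checkpoints/weights.pt"))    # local POSIX path
print(is_url("C:\\checkpoints\\weights.pt"))     # drive letter is not an http(s) scheme
```

Anything that is not an `http` or `https` URL (including other schemes such as `ftp`) falls through to the local-path branch, where a missing file raises `FileNotFoundError`.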
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def resnet18(
|
| 626 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 627 |
+
**kwargs
|
| 628 |
+
) -> ResNet:
|
| 629 |
+
"""ResNet-18 model with 3D operations.
|
| 630 |
+
"""
|
| 631 |
+
kwargs = dict(
|
| 632 |
+
block=BasicBlock,
|
| 633 |
+
layers=[2, 2, 2, 2],
|
| 634 |
+
cardinality=1,
|
| 635 |
+
**kwargs,
|
| 636 |
+
)
|
| 637 |
+
if checkpoint_path_or_url:
|
| 638 |
+
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 639 |
+
return ResNet(**kwargs)
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def resnet34(
|
| 643 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 644 |
+
**kwargs
|
| 645 |
+
) -> ResNet:
|
| 646 |
+
"""ResNet-34 model with 3D operations.
|
| 647 |
+
"""
|
| 648 |
+
kwargs = dict(
|
| 649 |
+
block=BasicBlock,
|
| 650 |
+
layers=[3, 4, 6, 3],
|
| 651 |
+
cardinality=1,
|
| 652 |
+
**kwargs,
|
| 653 |
+
)
|
| 654 |
+
if checkpoint_path_or_url:
|
| 655 |
+
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 656 |
+
return ResNet(**kwargs)
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
def resnet50(
|
| 660 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 661 |
+
**kwargs
|
| 662 |
+
) -> ResNet:
|
| 663 |
+
"""ResNet-50 model with 3D operations.
|
| 664 |
+
"""
|
| 665 |
+
kwargs = dict(
|
| 666 |
+
block=Bottleneck,
|
| 667 |
+
layers=[3, 4, 6, 3],
|
| 668 |
+
cardinality=1,
|
| 669 |
+
**kwargs,
|
| 670 |
+
)
|
| 671 |
+
if checkpoint_path_or_url:
|
| 672 |
+
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 673 |
+
return ResNet(**kwargs)
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def resnet101(
|
| 677 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 678 |
+
**kwargs
|
| 679 |
+
) -> ResNet:
|
| 680 |
+
"""ResNet-101 model with 3D operations.
|
| 681 |
+
"""
|
| 682 |
+
kwargs = dict(
|
| 683 |
+
block=Bottleneck,
|
| 684 |
+
layers=[3, 4, 23, 3],
|
| 685 |
+
cardinality=1,
|
| 686 |
+
**kwargs,
|
| 687 |
+
)
|
| 688 |
+
if checkpoint_path_or_url:
|
| 689 |
+
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 690 |
+
return ResNet(**kwargs)
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def resnext50(
|
| 694 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 695 |
+
**kwargs
|
| 696 |
+
) -> ResNet:
|
| 697 |
+
"""ResNeXt-50 model with 3D operations.
|
| 698 |
+
"""
|
| 699 |
+
kwargs = dict(
|
| 700 |
+
block=Bottleneck,
|
| 701 |
+
layers=[3, 4, 6, 3],
|
| 702 |
+
cardinality=32,
|
| 703 |
+
base_width=4,
|
| 704 |
+
**kwargs,
|
| 705 |
+
)
|
| 706 |
+
if checkpoint_path_or_url:
|
| 707 |
+
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 708 |
+
return ResNet(**kwargs)
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def resnext101(
|
| 712 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 713 |
+
**kwargs
|
| 714 |
+
) -> ResNet:
|
| 715 |
+
"""ResNeXt-101 model with 3D operations.
|
| 716 |
+
"""
|
| 717 |
+
kwargs = dict(
|
| 718 |
+
block=Bottleneck,
|
| 719 |
+
layers=[3, 4, 23, 3],
|
| 720 |
+
cardinality=32,
|
| 721 |
+
base_width=8,
|
| 722 |
+
**kwargs,
|
| 723 |
+
)
|
| 724 |
+
if checkpoint_path_or_url:
|
| 725 |
+
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 726 |
+
return ResNet(**kwargs)
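As a quick sanity check on the factory functions above, the per-stage block counts can be related to the nominal depth in each model name. This is a rough sketch, assuming the standard ResNet convention of 2 convolutions per `BasicBlock` and 3 per `Bottleneck` (the block definitions are not shown in this part of the file); `nominal_depth` is a hypothetical helper, not part of the repository.

```python
# Block counts per stage, copied from the factory functions above.
VARIANTS = {
    "resnet18": ("BasicBlock", [2, 2, 2, 2]),
    "resnet34": ("BasicBlock", [3, 4, 6, 3]),
    "resnet50": ("Bottleneck", [3, 4, 6, 3]),
    "resnet101": ("Bottleneck", [3, 4, 23, 3]),
}

# Assumption: standard ResNet convention for convolutions per block.
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def nominal_depth(block: str, layers: list) -> int:
    """Stem convolution + all block convolutions + final fully connected layer."""
    return 1 + CONVS_PER_BLOCK[block] * sum(layers) + 1


for name, (block, layers) in VARIANTS.items():
    print(name, nominal_depth(block, layers))
```

The ResNeXt variants reuse the layer counts of `resnet50` and `resnet101` and differ only in `cardinality` and `base_width`.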
|
spectre/models/seomt.py
ADDED
|
@@ -0,0 +1,394 @@
|
| 1 |
+
# Adapted from https://github.com/tue-mps/eomt/
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from spectre.models import VisionTransformer
|
| 11 |
+
from spectre.models.layers import LayerNorm3d
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ScaleBlock(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
embed_dim: int,
|
| 18 |
+
scale_factors: Union[int, Tuple[int, int, int]] = (2, 2, 2),
|
| 19 |
+
conv1_layer: nn.Module = nn.ConvTranspose3d,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
self.conv1 = conv1_layer(
|
| 24 |
+
embed_dim,
|
| 25 |
+
embed_dim,
|
| 26 |
+
kernel_size=scale_factors,
|
| 27 |
+
stride=scale_factors,
|
| 28 |
+
)
|
| 29 |
+
self.act = nn.GELU()
|
| 30 |
+
self.conv2 = nn.Conv3d(
|
| 31 |
+
embed_dim,
|
| 32 |
+
embed_dim,
|
| 33 |
+
kernel_size=3,
|
| 34 |
+
padding=1,
|
| 35 |
+
groups=embed_dim,
|
| 36 |
+
bias=False,
|
| 37 |
+
)
|
| 38 |
+
self.norm = LayerNorm3d(embed_dim)
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
# print(x.shape)
|
| 42 |
+
x = self.conv1(x)
|
| 43 |
+
x = self.act(x)
|
| 44 |
+
x = self.conv2(x)
|
| 45 |
+
x = self.norm(x)
|
| 46 |
+
|
| 47 |
+
return x
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def compute_upscale_stages(patch_size, min_size=4):
|
| 51 |
+
# Compute how many times to upscale per dimension
|
| 52 |
+
num_stages = []
|
| 53 |
+
for size in patch_size:
|
| 54 |
+
stages = max(0, int(math.log2(size)) - int(math.log2(min_size)))
|
| 55 |
+
num_stages.append(stages)
|
| 56 |
+
return num_stages
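To make the stage arithmetic concrete, here is a small standalone sketch. It repeats the same computation as `compute_upscale_stages` and adds a hypothetical helper, `stage_scale_factors`, that mirrors how `SEoMT.__init__` later in this file turns the stage counts into per-stage scale factors. The anisotropic patch size `(16, 16, 8)` is an assumed example, not a value taken from the repository.

```python
import math


def compute_upscale_stages(patch_size, min_size=4):
    # Same arithmetic as the function above: number of 2x steps per dimension.
    return [max(0, int(math.log2(s)) - int(math.log2(min_size))) for s in patch_size]


def stage_scale_factors(patch_size, min_size=4):
    """Hypothetical helper mirroring how SEoMT builds per-stage scale factors."""
    stages = compute_upscale_stages(patch_size, min_size)
    return [
        tuple(2 if idx < stages[dim] else 1 for dim in range(len(patch_size)))
        for idx in range(max(stages))
    ]


print(compute_upscale_stages((16, 16, 8)))  # [2, 2, 1]
print(stage_scale_factors((16, 16, 8)))     # [(2, 2, 2), (2, 2, 1)]
```

For power-of-two patch sizes of at least `min_size`, each dimension is upscaled by `patch_size / min_size` in total, so a dimension with a smaller patch simply stops upscaling earlier.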
|
| 57 |
+
|
| 58 |
+
def voxel_shuffle_3d(x: torch.Tensor, r: int = 2) -> torch.Tensor:
|
| 59 |
+
"""
|
| 60 |
+
Rearranges channels of a 5D tensor (N, C*r^3, D, H, W) to
|
| 61 |
+
(N, C, D*r, H*r, W*r).
|
| 62 |
+
"""
|
| 63 |
+
n, c, d, h, w = x.size()
|
| 64 |
+
assert c % (r ** 3) == 0, f"Channels {c} not divisible by r^3={r**3}"
|
| 65 |
+
out_c = c // (r ** 3)
|
| 66 |
+
x = x.view(n, out_c, r, r, r, d, h, w) # (N, C, r, r, r, D, H, W)
|
| 67 |
+
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() # (N, C, D, r, H, r, W, r)
|
| 68 |
+
x = x.view(n, out_c, d * r, h * r, w * r) # (N, C, D*r, H*r, W*r)
|
| 69 |
+
return x
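The view/permute/view sequence in `voxel_shuffle_3d` can be hard to follow from the shapes alone. The sketch below expresses the same rearrangement as a pure index mapping, with no tensors involved; `voxel_shuffle_index` is a hypothetical helper written for illustration.

```python
def voxel_shuffle_index(c_in: int, d: int, h: int, w: int, r: int = 2):
    """Map an input location (c_in, d, h, w) to its output location.

    Follows the view/permute/view sequence: the input channel index is split
    as c_in = ((c * r + i) * r + j) * r + k, and the offsets (i, j, k) are
    interleaved into the D, H and W axes respectively.
    """
    c, rem = divmod(c_in, r ** 3)
    i, rem = divmod(rem, r ** 2)
    j, k = divmod(rem, r)
    return c, d * r + i, h * r + j, w * r + k


print(voxel_shuffle_index(5, 3, 4, 5))  # (0, 7, 8, 11)
```

Each group of `r**3` consecutive input channels therefore fills one `r x r x r` block of a single output channel.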
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MLPUpBlock3D(nn.Module):
|
| 73 |
+
"""
|
| 74 |
+
2x upsampling via:
|
| 75 |
+
- 1x1x1 Conv (per-voxel MLP) expanding channels by 2^3
|
| 76 |
+
- 3D voxel shuffle to double D/H/W
|
| 77 |
+
- optional norm + activation
|
| 78 |
+
"""
|
| 79 |
+
def __init__(self, channels: int, norm=nn.Identity, activation=nn.GELU):
|
| 80 |
+
super().__init__()
|
| 81 |
+
self.proj = nn.Conv3d(channels, channels * 8, kernel_size=1, bias=True)
|
| 82 |
+
self.norm = norm(channels) if norm is not None else nn.Identity()
|
| 83 |
+
self.act = activation() if activation is not None else nn.Identity()
|
| 84 |
+
|
| 85 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 86 |
+
x = self.proj(x) # (N, C*8, D, H, W)
|
| 87 |
+
x = voxel_shuffle_3d(x, 2) # (N, C, 2D, 2H, 2W)
|
| 88 |
+
x = self.norm(x)
|
| 89 |
+
x = self.act(x)
|
| 90 |
+
return x
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SmolMLPDecoder3D(nn.Module):
|
| 94 |
+
"""
|
| 95 |
+
Simple decoder with two MLP upsampling layers (2x each) for total 4x upsampling.
|
| 96 |
+
"""
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
in_channels: int,
|
| 100 |
+
out_channels: int,
|
| 101 |
+
norm=LayerNorm3d,
|
| 102 |
+
activation=nn.GELU,
|
| 103 |
+
):
|
| 104 |
+
super().__init__()
|
| 105 |
+
self.up1 = MLPUpBlock3D(in_channels, norm=norm, activation=activation)
|
| 106 |
+
self.up2 = MLPUpBlock3D(in_channels, norm=norm, activation=activation)
|
| 107 |
+
self.head = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True)
|
| 108 |
+
|
| 109 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 110 |
+
x = self.up1(x) # 2x
|
| 111 |
+
x = self.up2(x) # another 2x => 4x total
|
| 112 |
+
x = self.head(x) # map to desired output channels
|
| 113 |
+
return x
|
| 114 |
+
|
| 115 |
+
class SpatialMLPDecoder3D(nn.Module):
|
| 116 |
+
"""
|
| 117 |
+
Per-voxel MLP that expands logits from (B,Q,H,W,D) to (B,Q,H*s,W*s,D*s)
|
| 118 |
+
by predicting a learned s^3 block per voxel, then rearranging spatially.
|
| 119 |
+
"""
|
| 120 |
+
def __init__(self, num_classes: int, upscale_factor: int = 4, hidden_mul: int = 4):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.num_classes = num_classes
|
| 123 |
+
self.s = upscale_factor
|
| 124 |
+
hidden = hidden_mul * num_classes
|
| 125 |
+
self.mlp = nn.Sequential(
|
| 126 |
+
nn.Linear(num_classes, hidden),
|
| 127 |
+
nn.GELU(),
|
| 128 |
+
nn.Linear(hidden, num_classes * (upscale_factor ** 3)),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
# x: (B, Q, H, W, D)
|
| 133 |
+
B, Q, H, W, D = x.shape
|
| 134 |
+
s = self.s
|
| 135 |
+
x = x.permute(0, 2, 3, 4, 1).contiguous().view(B * H * W * D, Q)  # (B*H*W*D, Q)
|
| 136 |
+
x = self.mlp(x)  # (B*H*W*D, Q*s^3)
|
| 137 |
+
x = x.view(B, H, W, D, Q, s, s, s) # (B,H,W,D,Q,s,s,s)
|
| 138 |
+
x = x.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous() # (B,Q,H,s,W,s,D,s)
|
| 139 |
+
x = x.view(B, Q, H * s, W * s, D * s)
|
| 140 |
+
return x
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class SimpleConvDecoder3D(nn.Module):
|
| 144 |
+
"""
|
| 145 |
+
Depthwise ConvTranspose3d upsampler:
|
| 146 |
+
- Assumes input & output channels == num_classes
|
| 147 |
+
- Uses groups=num_classes for class-wise deconvolution
|
| 148 |
+
"""
|
| 149 |
+
def __init__(self, num_classes: int, upscale_factor: int = 4):
|
| 150 |
+
super().__init__()
|
| 151 |
+
s = upscale_factor
|
| 152 |
+
self.deconv = nn.ConvTranspose3d(
|
| 153 |
+
in_channels=num_classes,
|
| 154 |
+
out_channels=num_classes,
|
| 155 |
+
kernel_size=s,
|
| 156 |
+
stride=s,
|
| 157 |
+
padding=0,
|
| 158 |
+
output_padding=0,
|
| 159 |
+
groups=num_classes,
|
| 160 |
+
bias=True,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 164 |
+
# x: (B, Q, H, W, D)
|
| 165 |
+
return self.deconv(x)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class SEoMT(nn.Module):
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
backbone: VisionTransformer,
|
| 172 |
+
num_classes: int,
|
| 173 |
+
# num_q: int,
|
| 174 |
+
num_blocks=4,
|
| 175 |
+
masked_attn_enabled=True,
|
| 176 |
+
return_only_final_layer=False,
|
| 177 |
+
upscale_output=True,
|
| 178 |
+
for_nnunet=False,
|
| 179 |
+
decoder=False,
|
| 180 |
+
):
|
| 181 |
+
super().__init__()
|
| 182 |
+
self.backbone = backbone
|
| 183 |
+
self.num_q = num_classes
|
| 184 |
+
self.num_blocks = num_blocks
|
| 185 |
+
self.masked_attn_enabled = masked_attn_enabled
|
| 186 |
+
self.return_only_final_layer = return_only_final_layer
|
| 187 |
+
self.upscale_output = upscale_output
|
| 188 |
+
self.register_buffer("attn_mask_probs", torch.ones(num_blocks))
|
| 189 |
+
self.for_nnunet = for_nnunet
|
| 190 |
+
self.q = nn.Embedding(num_classes, self.backbone.embed_dim)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
self.mask_head = nn.Sequential(
|
| 194 |
+
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 195 |
+
nn.GELU(),
|
| 196 |
+
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 197 |
+
nn.GELU(),
|
| 198 |
+
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
patch_size = self.backbone.patch_embed.patch_size
|
| 202 |
+
num_upscale_stages = compute_upscale_stages(patch_size, min_size=4)
|
| 203 |
+
|
| 204 |
+
# Build per-stage scale factors list
|
| 205 |
+
max_stages = max(num_upscale_stages)
|
| 206 |
+
upscale_blocks = []
|
| 207 |
+
for stage_idx in range(max_stages):
|
| 208 |
+
# for each dimension, upscale by 2 only if this dimension still has
|
| 209 |
+
# remaining upscales at this stage
|
| 210 |
+
scale_factors = tuple(
|
| 211 |
+
2 if stage_idx < num_upscale_stages[dim] else 1
|
| 212 |
+
for dim in range(len(patch_size))
|
| 213 |
+
)
|
| 214 |
+
upscale_blocks.append(ScaleBlock(self.backbone.embed_dim,
|
| 215 |
+
scale_factors=scale_factors))
|
| 216 |
+
|
| 217 |
+
self.upscale = nn.Sequential(*upscale_blocks)
|
| 218 |
+
|
| 219 |
+
def _predict(self, x: torch.Tensor, stage: Optional[int] = None):
|
| 220 |
+
q = x[:, : self.num_q, :]
|
| 221 |
+
# print(stage)
|
| 222 |
+
# class_logits = self.class_head(q)
|
| 223 |
+
x = x[:, self.num_q + self.backbone.num_prefix_tokens :, :]
|
| 224 |
+
x = x.transpose(1, 2).reshape(
|
| 225 |
+
x.shape[0], -1, *self.backbone.patch_embed.grid_size
|
| 226 |
+
)
|
| 227 |
+
mask_logits = torch.einsum(
|
| 228 |
+
"bqc, bchwd -> bqhwd", self.mask_head(q), self.upscale(x)
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
return mask_logits
|
| 233 |
+
|
| 234 |
+
@torch.compiler.disable
|
| 235 |
+
def _disable_attn_mask(self, attn_mask, prob):
|
| 236 |
+
if prob < 1:
|
| 237 |
+
random_queries = (
|
| 238 |
+
torch.rand(attn_mask.shape[0], self.num_q, device=attn_mask.device)
|
| 239 |
+
> prob
|
| 240 |
+
)
|
| 241 |
+
attn_mask[
|
| 242 |
+
:, : self.num_q, self.num_q + self.backbone.num_prefix_tokens :
|
| 243 |
+
][random_queries] = True
|
| 244 |
+
|
| 245 |
+
return attn_mask
|
| 246 |
+
|
| 247 |
+
def _attn(self, module: 'Attention', x: torch.Tensor, mask: Optional[torch.Tensor], rope=None):
|
| 248 |
+
B, N, C = x.shape
|
| 249 |
+
|
| 250 |
+
q = module.q(x).reshape(B, N, module.num_heads, module.head_dim).permute(0, 2, 1, 3)
|
| 251 |
+
kv = module.kv(x).reshape(B, N, 2, module.num_heads, module.head_dim)
|
| 252 |
+
k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
|
| 253 |
+
q, k = module.q_norm(q), module.k_norm(k)
|
| 254 |
+
|
| 255 |
+
if mask is not None:
|
| 256 |
+
mask = mask[:, None, ...].expand(-1, module.num_heads, -1, -1)
|
| 257 |
+
|
| 258 |
+
dropout_p = module.attn_drop.p if self.training else 0.0
|
| 259 |
+
|
| 260 |
+
if rope is not None:
|
| 261 |
+
if isinstance(rope, list):
|
| 262 |
+
rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
|
| 263 |
+
q, k = module.apply_rotary_pos_emb(q, k, rope)
|
| 264 |
+
|
| 265 |
+
if module.fused_attn:
|
| 266 |
+
x = F.scaled_dot_product_attention(q, k, v, mask, dropout_p)
|
| 267 |
+
else:
|
| 268 |
+
attn = (q @ k.transpose(-2, -1)) * module.scale
|
| 269 |
+
if mask is not None:
|
| 270 |
+
attn = attn.masked_fill(~mask, float("-inf"))
|
| 271 |
+
attn = F.softmax(attn, dim=-1)
|
| 272 |
+
attn = module.attn_drop(attn)
|
| 273 |
+
x = attn @ v
|
| 274 |
+
|
| 275 |
+
x = module.proj_drop(module.proj(x.transpose(1, 2).reshape(B, N, C)))
|
| 276 |
+
|
| 277 |
+
return x
|
| 278 |
+
|
| 279 |
+
def forward(self, x: torch.Tensor):
|
| 280 |
+
|
| 281 |
+
if self.for_nnunet:  # nnU-Net supplies (B, C, Z, Y, X); move the Z axis last to get (B, C, Y, X, Z)
|
| 282 |
+
x = x.permute(0, 1, 3, 4, 2).contiguous()
|
| 283 |
+
|
| 284 |
+
self.backbone.patch_embed.set_input_size(x.shape[2:])
|
| 285 |
+
x = self.backbone.patch_embed(x)
|
| 286 |
+
x, rope = self.backbone._pos_embed(x)
|
| 287 |
+
x = self.backbone.patch_drop(x)
|
| 288 |
+
x = self.backbone.norm_pre(x)
|
| 289 |
+
attn_mask = None
|
| 290 |
+
mask_logits_per_layer = []
|
| 291 |
+
|
| 292 |
+
for i, block in enumerate(self.backbone.blocks):
|
| 293 |
+
if i == len(self.backbone.blocks) - self.num_blocks:
|
| 294 |
+
x = torch.cat(
|
| 295 |
+
(self.q.weight[None, :, :].expand(x.shape[0], -1, -1), x), dim=1
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
if (
|
| 299 |
+
self.masked_attn_enabled
|
| 300 |
+
and i >= len(self.backbone.blocks) - self.num_blocks
|
| 301 |
+
):
|
| 302 |
+
mask_logits = self._predict(self.backbone.norm(x))
|
| 303 |
+
|
| 304 |
+
if self.for_nnunet:
|
| 305 |
+
# swap back to czyx
|
| 306 |
+
|
| 307 |
+
if self.upscale_output:
|
| 308 |
+
# Upscale to original input size / stage
|
| 309 |
+
|
| 310 |
+
stage = len(self.backbone.blocks) - i
|
| 311 |
+
if stage is not None:
|
| 312 |
+
input_size = tuple(
|
| 313 |
+
int(self.backbone.patch_embed.patch_size[dim] * self.backbone.patch_embed.grid_size[dim] / 2**(stage))
|
| 314 |
+
for dim in range(len(self.backbone.patch_embed.patch_size))
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
mask_logits_per_layer.append(F.interpolate(mask_logits, input_size, mode="trilinear").permute(0, 1, 4, 2, 3).contiguous())
|
| 318 |
+
else:
|
| 319 |
+
mask_logits_per_layer.append(mask_logits.permute(0, 1, 4, 2, 3).contiguous())
|
| 320 |
+
else:
|
| 321 |
+
mask_logits_per_layer.append(mask_logits)
|
| 322 |
+
|
| 323 |
+
attn_mask = torch.ones(
|
| 324 |
+
x.shape[0],
|
| 325 |
+
x.shape[1],
|
| 326 |
+
x.shape[1],
|
| 327 |
+
dtype=torch.bool,
|
| 328 |
+
device=x.device,
|
| 329 |
+
)
|
| 330 |
+
interpolated = F.interpolate(
|
| 331 |
+
mask_logits,
|
| 332 |
+
self.backbone.patch_embed.grid_size,
|
| 333 |
+
mode="trilinear",
|
| 334 |
+
)
|
| 335 |
+
interpolated = interpolated.view(
|
| 336 |
+
interpolated.size(0), interpolated.size(1), -1
|
| 337 |
+
)
|
| 338 |
+
attn_mask[
|
| 339 |
+
:,
|
| 340 |
+
: self.num_q,
|
| 341 |
+
self.num_q + self.backbone.num_prefix_tokens :,
|
| 342 |
+
] = (
|
| 343 |
+
interpolated > 0
|
| 344 |
+
)
|
| 345 |
+
attn_mask = self._disable_attn_mask(
|
| 346 |
+
attn_mask,
|
| 347 |
+
self.attn_mask_probs[
|
| 348 |
+
i - len(self.backbone.blocks) + self.num_blocks
|
| 349 |
+
],
|
| 350 |
+
)
|
| 351 |
+
x = x + block.drop_path1(
|
| 352 |
+
block.ls1(self._attn(block.attn, block.norm1(x), attn_mask, rope=rope))
|
| 353 |
+
)
|
| 354 |
+
x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))
|
| 355 |
+
|
| 356 |
+
mask_logits = self._predict(self.backbone.norm(x))
|
| 357 |
+
if self.for_nnunet:
|
| 358 |
+
input_size = tuple(
|
| 359 |
+
int(self.backbone.patch_embed.patch_size[dim] * self.backbone.patch_embed.grid_size[dim] / 2**0)
|
| 360 |
+
for dim in range(len(self.backbone.patch_embed.patch_size))
|
| 361 |
+
)
|
| 362 |
+
mask_logits_per_layer.append(F.interpolate(mask_logits, input_size, mode="trilinear").permute(0, 1, 4, 2, 3).contiguous())
|
| 363 |
+
else:
|
| 364 |
+
mask_logits_per_layer.append(mask_logits)
|
| 365 |
+
|
| 366 |
+
if self.for_nnunet:
|
| 367 |
+
# return in reversed order for deep supervision
|
| 368 |
+
mask_logits_per_layer = mask_logits_per_layer[::-1]
|
| 369 |
+
return mask_logits_per_layer if not self.return_only_final_layer else mask_logits_per_layer[0]
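The deep-supervision output sizes produced by `forward` can be worked out by hand. The sketch below assumes `for_nnunet=True`, `upscale_output=True` and `masked_attn_enabled=True`, and reads the flattened diff as nesting the `upscale_output` branch inside the `for_nnunet` branch. `deep_supervision_sizes` is a hypothetical helper, and the cubic patch size of 16 with grid `(8, 8, 4)` is an assumed example.

```python
def deep_supervision_sizes(patch_size, grid_size, num_blocks):
    """Hypothetical helper: spatial sizes passed to F.interpolate, final output first."""
    full = [p * g for p, g in zip(patch_size, grid_size)]
    sizes = []
    # Intermediate predictions: stage = len(blocks) - i runs num_blocks, ..., 1.
    for stage in range(num_blocks, 0, -1):
        sizes.append(tuple(int(f / 2 ** stage) for f in full))
    # Final prediction after the last block is at full resolution (2**0).
    sizes.append(tuple(full))
    # The model reverses the list so the full-resolution output comes first.
    return sizes[::-1]


for size in deep_supervision_sizes((16, 16, 16), (8, 8, 4), num_blocks=4):
    print(size)
```

These sizes are in the model's internal axis order; each tensor is permuted back to the nnU-Net layout before being appended to the list.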
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
if __name__ == "__main__":
|
| 373 |
+
from spectre.models import vit_large_patch16_128
|
| 374 |
+
|
| 375 |
+
model = SEoMT(
|
| 376 |
+
backbone=vit_large_patch16_128(pos_embed='rope',
|
| 377 |
+
rope_kwargs={
|
| 378 |
+
"base": 1000.0, # works for most 3D models
|
| 379 |
+
},),
|
| 380 |
+
num_classes=4,
|
| 381 |
+
num_blocks=4,
|
| 382 |
+
masked_attn_enabled=True,
|
| 383 |
+
return_only_final_layer=True,
|
| 384 |
+
for_nnunet=True,
|
| 385 |
+
upscale_output=True,
|
| 386 |
+
decoder=False,
|
| 387 |
+
)
|
| 388 |
+
# print number of parameters
|
| 389 |
+
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
| 390 |
+
|
| 391 |
+
x = torch.randn(2, 1, 64, 128, 128)
|
| 392 |
+
out = model(x)
|
| 393 |
+
for o in out:
|
| 394 |
+
print(o.shape)
|
spectre/models/upsample_anything.py
ADDED
|
@@ -0,0 +1,319 @@
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def UPA(hr_image, lr_volume, device="cuda", use_amp=True):
|
| 9 |
+
"""
|
| 10 |
+
hr_image: numpy or torch [C,Hh,Wh,Dh]
|
| 11 |
+
lr_volume: torch [1,C,Hl,Wl,Dl]
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
hr = torch.as_tensor(hr_image).unsqueeze(0).float().to(device)
|
| 15 |
+
|
| 16 |
+
_, _, Hh, Wh, Dh = hr.shape
|
| 17 |
+
_, _, Hl, Wl, Dl = lr_volume.shape
|
| 18 |
+
scale = Hh // Hl
|
| 19 |
+
assert Wh // Wl == scale and Dh // Dl == scale, "Inconsistent scale factors"
|
| 20 |
+
|
| 21 |
+
lr_volume = lr_volume.to(device).float()
|
| 22 |
+
lr = F.interpolate(hr, scale_factor=1/scale, mode="trilinear", align_corners=False)
|
| 23 |
+
|
| 24 |
+
model = LearnablePixelwiseAnisoJBU3D(
|
| 25 |
+
Hl, Wl, Dl, scale=scale
|
| 26 |
+
).to(device)
|
| 27 |
+
|
| 28 |
+
model.train()
|
| 29 |
+
opt = torch.optim.Adam(model.parameters(), lr=1e-1)
|
| 30 |
+
max_steps = 350
|
| 31 |
+
gamma = (1e-9 / 1e-1) ** (1.0 / max_steps)
|
| 32 |
+
scheduler = LambdaLR(opt, lr_lambda=lambda step: gamma ** step)
|
| 33 |
+
scaler = torch.amp.GradScaler(device=device, enabled=use_amp)
|
| 34 |
+
|
| 35 |
+
for step in range(max_steps):
|
| 36 |
+
opt.zero_grad(set_to_none=True)
|
| 37 |
+
with torch.amp.autocast(device_type=device, enabled=use_amp):
|
| 38 |
+
pred = model(lr, hr)
|
| 39 |
+
loss = F.l1_loss(pred, hr)
|
| 40 |
+
|
| 41 |
+
scaler.scale(loss).backward()
|
| 42 |
+
scaler.step(opt)
|
| 43 |
+
scaler.update()
|
| 44 |
+
scheduler.step()
|
| 45 |
+
|
| 46 |
+
if step % 50 == 0:
|
| 47 |
+
print(f"step {step}: loss={loss.item():.5f}")
|
| 48 |
+
|
| 49 |
+
model.eval()
|
| 50 |
+
with torch.inference_mode(), \
|
| 51 |
+
torch.amp.autocast(device_type=device, enabled=use_amp, dtype=torch.float16):
|
| 52 |
+
out = model(lr_volume, hr)
|
| 53 |
+
|
| 54 |
+
return out
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@torch.no_grad()
|
| 58 |
+
def _build_offsets_3d(R_max: int, device):
|
| 59 |
+
offs = torch.arange(-R_max, R_max + 1, device=device)
|
| 60 |
+
dX, dY, dZ = torch.meshgrid(offs, offs, offs, indexing="ij")
|
| 61 |
+
return (
|
| 62 |
+
dX.reshape(-1),
|
| 63 |
+
dY.reshape(-1),
|
| 64 |
+
dZ.reshape(-1),
|
| 65 |
+
) # [K]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def gather_lr_scalar_3d(map_lr, Ui, Vi, Wi):
|
| 69 |
+
"""
|
| 70 |
+
map_lr: [1,1,Hl,Wl,Dl] or [Hl,Wl,Dl]
|
| 71 |
+
Ui,Vi,Wi: [Bn,Hh,Wh,Dh]
|
| 72 |
+
"""
|
| 73 |
+
Hl, Wl, Dl = map_lr.shape[-3:]
|
| 74 |
+
flat = Hl * Wl * Dl
|
| 75 |
+
idx = (Ui * Wl * Dl + Vi * Dl + Wi).reshape(-1)
|
| 76 |
+
t = map_lr.view(flat)
|
| 77 |
+
vals = t.index_select(0, idx)
|
| 78 |
+
return vals.view(Ui.shape)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def gs_jbu_aniso_noparent_3d(
|
| 82 |
+
feat_lr, # [1,C,Hl,Wl,Dl]
|
| 83 |
+
guide_hr, # [1,G,Hh,Wh,Dh]
|
| 84 |
+
scale,
|
| 85 |
+
sigma_x_map,
|
| 86 |
+
sigma_y_map,
|
| 87 |
+
sigma_z_map,
|
| 88 |
+
sigma_r_map,
|
| 89 |
+
R_max=3,
|
| 90 |
+
alpha_dyn=2.0,
|
| 91 |
+
C_chunk=64,
|
| 92 |
+
Nn_chunk=125,
|
| 93 |
+
):
|
| 94 |
+
_, C, Hl, Wl, Dl = feat_lr.shape
|
| 95 |
+
_, _, Hh, Wh, Dh = guide_hr.shape
|
| 96 |
+
device = feat_lr.device
|
| 97 |
+
dtype_feat = feat_lr.dtype
|
| 98 |
+
|
| 99 |
+
# HR grid
|
| 100 |
+
x = torch.arange(Hh, device=device, dtype=torch.float32)
|
| 101 |
+
y = torch.arange(Wh, device=device, dtype=torch.float32)
|
| 102 |
+
z = torch.arange(Dh, device=device, dtype=torch.float32)
|
| 103 |
+
X, Y, Z = torch.meshgrid(x, y, z, indexing="ij")
|
| 104 |
+
|
| 105 |
+
u = (X + 0.5) / scale - 0.5
|
| 106 |
+
v = (Y + 0.5) / scale - 0.5
|
| 107 |
+
w = (Z + 0.5) / scale - 0.5
|
| 108 |
+
|
| 109 |
+
uc = torch.round(u).clamp(0, Hl - 1).long()
|
| 110 |
+
vc = torch.round(v).clamp(0, Wl - 1).long()
|
| 111 |
+
wc = torch.round(w).clamp(0, Dl - 1).long()
|
| 112 |
+
|
| 113 |
+
# Dynamic radius
|
| 114 |
+
sigma_eff = torch.maximum(
|
| 115 |
+
sigma_x_map,
|
| 116 |
+
torch.maximum(sigma_y_map, sigma_z_map),
|
| 117 |
+
)
|
| 118 |
+
sigma_eff_hr = F.interpolate(
|
| 119 |
+
sigma_eff, (Hh, Wh, Dh), mode="trilinear", align_corners=False
|
| 120 |
+
)
|
| 121 |
+
# sigma_eff_hr = sigma_eff_hr.squeeze(0).squeeze(0)
|
| 122 |
+
R_map = torch.ceil(alpha_dyn * sigma_eff_hr).clamp(1, R_max).long()
|
| 123 |
+
|
| 124 |
+
dX_all, dY_all, dZ_all = _build_offsets_3d(R_max, device)
|
| 125 |
+
|
| 126 |
+
num = torch.zeros(C, Hh, Wh, Dh, device=device, dtype=torch.float32)
|
| 127 |
+
den = torch.zeros(Hh, Wh, Dh, device=device, dtype=torch.float32)
|
| 128 |
+
m = torch.full((Hh, Wh, Dh), -1e9, device=device, dtype=torch.float32)
|
| 129 |
+
|
| 130 |
+
feat_flat = feat_lr[0].permute(1, 2, 3, 0).reshape(-1, C)
|
| 131 |
+
guide_lr = F.interpolate(
|
| 132 |
+
guide_hr, (Hl, Wl, Dl), mode="trilinear", align_corners=False
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
for n0 in range(0, len(dX_all), Nn_chunk):
|
| 136 |
+
dX = dX_all[n0:n0+Nn_chunk][:, None, None, None]
|
| 137 |
+
dY = dY_all[n0:n0+Nn_chunk][:, None, None, None]
|
| 138 |
+
dZ = dZ_all[n0:n0+Nn_chunk][:, None, None, None]
|
| 139 |
+
|
| 140 |
+
Ui = torch.clamp(uc.unsqueeze(0) + dX, 0, Hl - 1)
|
| 141 |
+
Vi = torch.clamp(vc.unsqueeze(0) + dY, 0, Wl - 1)
|
| 142 |
+
Wi = torch.clamp(wc.unsqueeze(0) + dZ, 0, Dl - 1)
|
| 143 |
+
|
| 144 |
+
# mask = (dX**2 + dY**2 + dZ**2 <= R_map[None, ...] ** 2)
|
| 145 |
+
mask = (dX**2 + dY**2 + dZ**2 <= R_map**2).squeeze(0).squeeze(0)
|
| 146 |
+
|
| 147 |
+
cx = (Ui.float() + 0.5) * scale - 0.5
|
| 148 |
+
cy = (Vi.float() + 0.5) * scale - 0.5
|
| 149 |
+
cz = (Wi.float() + 0.5) * scale - 0.5
|
| 150 |
+
|
| 151 |
+
dx = X.unsqueeze(0) - cx
|
| 152 |
+
dy = Y.unsqueeze(0) - cy
|
| 153 |
+
dz = Z.unsqueeze(0) - cz
|
| 154 |
+
|
| 155 |
+
sx = gather_lr_scalar_3d(sigma_x_map, Ui, Vi, Wi).clamp_min(1e-6)
|
| 156 |
+
sy = gather_lr_scalar_3d(sigma_y_map, Ui, Vi, Wi).clamp_min(1e-6)
|
| 157 |
+
sz = gather_lr_scalar_3d(sigma_z_map, Ui, Vi, Wi).clamp_min(1e-6)
|
| 158 |
+
sr = gather_lr_scalar_3d(sigma_r_map, Ui, Vi, Wi).clamp_min(1e-6)
|
| 159 |
+
|
| 160 |
+
log_ws = (
|
| 161 |
+
-(dx**2)/(2*sx**2)
|
| 162 |
+
-(dy**2)/(2*sy**2)
|
| 163 |
+
-(dz**2)/(2*sz**2)
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
diff2 = 0.0
|
| 167 |
+
for g in range(guide_hr.shape[1]):
|
| 168 |
+
g0 = gather_lr_scalar_3d(guide_lr[0, g], Ui, Vi, Wi)
|
| 169 |
+
diff2 += (guide_hr[0, g] - g0) ** 2
|
| 170 |
+
|
| 171 |
+
log_wr = -diff2 / (2 * sr**2 + 1e-8)
|
| 172 |
+
log_w = torch.where(mask, log_ws + log_wr, -1e9)
|
| 173 |
+
|
| 174 |
+
m_chunk = log_w.max(dim=0).values
|
| 175 |
+
m_new = torch.maximum(m, m_chunk)
|
| 176 |
+
|
| 177 |
+
scale_old = torch.exp(m - m_new)
|
| 178 |
+
num *= scale_old
|
| 179 |
+
den *= scale_old
|
| 180 |
+
|
| 181 |
+
w = torch.exp(log_w - m_new)
|
| 182 |
+
den += w.sum(0)
|
| 183 |
+
|
| 184 |
+
idx_flat = (Ui * Wl * Dl + Vi * Dl + Wi).reshape(-1)
|
| 185 |
+
|
| 186 |
+
for c0 in range(0, C, C_chunk):
|
| 187 |
+
c1 = min(c0 + C_chunk, C)
|
| 188 |
+
f = feat_flat.index_select(0, idx_flat)[:, c0:c1]
|
| 189 |
+
f = f.view(w.shape + (c1 - c0,))
|
| 190 |
+
num[c0:c1] += (f * w[..., None]).sum(0).permute(3, 0, 1, 2)
|
| 191 |
+
|
| 192 |
+
m = m_new
|
| 193 |
+
|
| 194 |
+
out = (num / den.clamp_min(1e-8)).unsqueeze(0)
|
| 195 |
+
return out.to(dtype_feat)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class LearnablePixelwiseAnisoJBU3D(nn.Module):
|
| 199 |
+
def __init__(
|
| 200 |
+
self,
|
| 201 |
+
Hl,
|
| 202 |
+
Wl,
|
| 203 |
+
Dl,
|
| 204 |
+
scale,
|
| 205 |
+
init_sigma=1.5,
|
| 206 |
+
init_sigma_r=0.1,
|
| 207 |
+
R_max=3,
|
| 208 |
+
alpha_dyn=2.0,
|
| 209 |
+
):
|
| 210 |
+
super().__init__()
|
| 211 |
+
self.scale = scale
|
| 212 |
+
self.R_max = R_max
|
| 213 |
+
self.alpha_dyn = alpha_dyn
|
| 214 |
+
|
| 215 |
+
self.sx_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
|
| 216 |
+
self.sy_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
|
| 217 |
+
self.sz_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
|
| 218 |
+
self.sr_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma_r)))
|
| 219 |
+
|
| 220 |
+
def forward(self, feat_lr, guide_hr):
|
| 221 |
+
return gs_jbu_aniso_noparent_3d(
|
| 222 |
+
feat_lr,
|
| 223 |
+
guide_hr,
|
| 224 |
+
self.scale,
|
| 225 |
+
torch.exp(self.sx_raw),
|
| 226 |
+
torch.exp(self.sy_raw),
|
| 227 |
+
torch.exp(self.sz_raw),
|
| 228 |
+
torch.exp(self.sr_raw),
|
| 229 |
+
R_max=self.R_max,
|
| 230 |
+
alpha_dyn=self.alpha_dyn,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
if __name__ == "__main__":
|
| 235 |
+
import argparse
|
| 236 |
+
|
| 237 |
+
import numpy as np
|
| 238 |
+
import nibabel as nib
|
| 239 |
+
import monai.transforms as transforms
|
| 240 |
+
|
| 241 |
+
parser = argparse.ArgumentParser()
|
| 242 |
+
parser.add_argument("--image_path", type=str, required=True)
|
| 243 |
+
parser.add_argument("--mask_path", type=str, required=True)
|
| 244 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 245 |
+
parser.add_argument("--use_amp", action="store_true")
|
| 246 |
+
args = parser.parse_args()
|
| 247 |
+
|
| 248 |
+
transform = transforms.Compose([
|
| 249 |
+
transforms.LoadImaged(keys=("image", "mask")),
|
| 250 |
+
transforms.EnsureChannelFirstd(keys=("image", "mask"), channel_dim="no_channel"),
|
| 251 |
+
transforms.ScaleIntensityRanged(
|
| 252 |
+
keys=("image",),
|
| 253 |
+
a_min=-150,
|
| 254 |
+
a_max=250,
|
| 255 |
+
b_min=0.0,
|
| 256 |
+
b_max=1.0,
|
| 257 |
+
clip=True,
|
| 258 |
+
),
|
| 259 |
+
transforms.Orientationd(keys=("image", "mask"), axcodes="RAS"),
|
| 260 |
+
transforms.RandWeightedCropd(
|
| 261 |
+
keys=("image", "mask"),
|
| 262 |
+
w_key="mask",
|
| 263 |
+
spatial_size=(128, 128, 64),
|
| 264 |
+
num_samples=1,
|
| 265 |
+
),
|
| 266 |
+
transforms.CopyItemsd(keys=("mask",), times=1, names=("mask_low_res",)),
|
| 267 |
+
transforms.Resized(keys=("mask_low_res",), spatial_size=(16, 16, 8), mode="nearest"),  # align_corners must stay unset for nearest-neighbour interpolation
|
| 268 |
+
])
|
| 269 |
+
sample = transform({
|
| 270 |
+
"image": args.image_path,
|
| 271 |
+
"mask": args.mask_path,
|
| 272 |
+
})[0]
|
| 273 |
+
|
| 274 |
+
nib.save(
|
| 275 |
+
nib.Nifti1Image(
|
| 276 |
+
(F.interpolate(sample["mask_low_res"].unsqueeze(0), size=(128, 128, 64), mode="nearest").squeeze(0).squeeze(0).cpu().numpy().astype(np.uint8)),
|
| 277 |
+
affine=np.eye(4),
|
| 278 |
+
),
|
| 279 |
+
"mask_low_res_upscaled.nii.gz",
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
sample["mask_low_res"] = F.one_hot(
|
| 283 |
+
sample["mask_low_res"].long().squeeze(0), num_classes=4,
|
| 284 |
+
).permute(3, 0, 1, 2).unsqueeze(0).float()
|
| 285 |
+
|
| 286 |
+
print(sample["mask_low_res"].shape)
|
| 287 |
+
|
| 288 |
+
mask_out = UPA(
|
| 289 |
+
sample["image"],
|
| 290 |
+
sample["mask_low_res"],
|
| 291 |
+
device=args.device,
|
| 292 |
+
use_amp=args.use_amp,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
mask_out = mask_out.argmax(dim=1, keepdim=True)
|
| 296 |
+
|
| 297 |
+
nib.save(
|
| 298 |
+
nib.Nifti1Image(
|
| 299 |
+
(sample["image"] * 255).squeeze(0).cpu().numpy().astype(np.uint8),
|
| 300 |
+
affine=np.eye(4),
|
| 301 |
+
),
|
| 302 |
+
"image.nii.gz",
|
| 303 |
+
)
|
| 304 |
+
nib.save(
|
| 305 |
+
nib.Nifti1Image(
|
| 306 |
+
sample["mask"].squeeze(0).cpu().numpy().astype(np.uint8),
|
| 307 |
+
affine=np.eye(4),
|
| 308 |
+
),
|
| 309 |
+
"mask.nii.gz",
|
| 310 |
+
)
|
| 311 |
+
torch.save(mask_out.squeeze(0).squeeze(0).cpu(), "upsampled_mask.pt")
|
| 312 |
+
nib.save(
|
| 313 |
+
nib.Nifti1Image(
|
| 314 |
+
mask_out.squeeze(0).squeeze(0).cpu().numpy().astype(np.uint8),
|
| 315 |
+
affine=np.eye(4),
|
| 316 |
+
),
|
| 317 |
+
"upsampled_mask.nii.gz",
|
| 318 |
+
)
|
| 319 |
+
|
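Note on the chunked accumulation in `upsample_anything.py` above (source lines 174 to 192): the kernel keeps a running maximum `m` of the log-weights and rescales `num` and `den` by `exp(m - m_new)` before adding each new chunk, so the normalised result matches a single softmax-weighted average over all neighbours without ever exponentiating a large value. The following is a minimal sketch of that idea on plain Python floats, not the repository's code. The helper names `streaming_weighted_mean` and `direct_weighted_mean` are hypothetical, and the sketch assumes finite log-weights (the kernel uses `-1e9` rather than `-inf` for masked entries).

```python
import math


def streaming_weighted_mean(log_weights, values, chunk_size=2):
    """Softmax-weighted mean of `values`, accumulated chunk by chunk.

    Hypothetical helper for illustration only. Assumes all log-weights
    are finite.
    """
    m = -math.inf  # running maximum of the log-weights seen so far
    num = 0.0      # running sum of exp(lw - m) * value
    den = 0.0      # running sum of exp(lw - m)
    for start in range(0, len(log_weights), chunk_size):
        lw = log_weights[start:start + chunk_size]
        vs = values[start:start + chunk_size]
        m_new = max(m, max(lw))
        # Re-express the sums accumulated under the old maximum
        # relative to the new maximum.
        scale_old = math.exp(m - m_new)
        num *= scale_old
        den *= scale_old
        for l, v in zip(lw, vs):
            w = math.exp(l - m_new)  # always <= 1, so no overflow
            num += w * v
            den += w
        m = m_new
    return num / max(den, 1e-8)


def direct_weighted_mean(log_weights, values):
    """Reference: the same quantity computed in one pass with a global max."""
    m = max(log_weights)
    ws = [math.exp(l - m) for l in log_weights]
    return sum(w * v for w, v in zip(ws, values)) / sum(ws)


if __name__ == "__main__":
    lws = [0.0, 1000.0, 999.0, -5.0, 3.0]
    vals = [1.0, 2.0, 3.0, 4.0, 5.0]
    print(streaming_weighted_mean(lws, vals), direct_weighted_mean(lws, vals))
```

The rescaling step is what allows the neighbour offsets to be processed in chunks of `Nn_chunk` while holding only one chunk of weights in memory at a time.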
spectre/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,835 @@
|
| 1 |
+
import os
|
| 2 |
+
from functools import partial
|
| 3 |
+
from urllib.parse import urlparse
|
| 4 |
+
from typing import (
|
| 5 |
+
Tuple, Union, Callable, Literal,
|
| 6 |
+
Optional, Type, Set, List, Dict, Any,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from timm.layers import PatchDropout, AttentionPoolLatent
|
| 12 |
+
from timm.models.vision_transformer import LayerScale, DropPath, Mlp
|
| 13 |
+
from huggingface_hub import hf_hub_download, load_state_dict_from_file
|
| 14 |
+
|
| 15 |
+
from spectre.models.layers import (
|
| 16 |
+
PatchEmbed,
|
| 17 |
+
Attention,
|
| 18 |
+
RotaryPositionEmbedding,
|
| 19 |
+
)
|
| 20 |
+
from spectre.utils import (
|
| 21 |
+
resample_abs_pos_embed,
|
| 22 |
+
feature_take_indices,
|
| 23 |
+
global_pool_nlc,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Block(nn.Module):
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
dim: int,
|
| 31 |
+
num_heads: int,
|
| 32 |
+
attn_mode: str = 'mha',
|
| 33 |
+
q_proj_dim: Optional[int] = None,
|
| 34 |
+
kv_proj_dim: Optional[int] = None,
|
| 35 |
+
mlp_ratio: float = 4.,
|
| 36 |
+
qkv_bias: bool = False,
|
| 37 |
+
qk_norm: bool = False,
|
| 38 |
+
proj_bias: bool = True,
|
| 39 |
+
proj_drop: float = 0.,
|
| 40 |
+
attn_drop: float = 0.,
|
| 41 |
+
init_values: Optional[float] = None,
|
| 42 |
+
drop_path: float = 0.,
|
| 43 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
| 44 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
| 45 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.norm1 = norm_layer(dim)
|
| 49 |
+
self.attn = Attention(
|
| 50 |
+
dim,
|
| 51 |
+
num_heads=num_heads,
|
| 52 |
+
mode=attn_mode,
|
| 53 |
+
q_proj_dim=q_proj_dim,
|
| 54 |
+
kv_proj_dim=kv_proj_dim,
|
| 55 |
+
qkv_bias=qkv_bias,
|
| 56 |
+
qk_norm=qk_norm,
|
| 57 |
+
proj_bias=proj_bias,
|
| 58 |
+
attn_drop=attn_drop,
|
| 59 |
+
proj_drop=proj_drop,
|
| 60 |
+
norm_layer=norm_layer,
|
| 61 |
+
)
|
| 62 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 63 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 64 |
+
|
| 65 |
+
self.norm2 = norm_layer(dim)
|
| 66 |
+
self.mlp = mlp_layer(
|
| 67 |
+
in_features=dim,
|
| 68 |
+
hidden_features=int(dim * mlp_ratio),
|
| 69 |
+
act_layer=act_layer,
|
| 70 |
+
bias=proj_bias,
|
| 71 |
+
drop=proj_drop,
|
| 72 |
+
)
|
| 73 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 74 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 75 |
+
|
| 76 |
+
def forward(
|
| 77 |
+
self,
|
| 78 |
+
x: torch.Tensor,
|
| 79 |
+
rope = None
|
| 80 |
+
) -> torch.Tensor:
|
| 81 |
+
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rope=rope)))
|
| 82 |
+
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
| 83 |
+
return x
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class VisionTransformer(nn.Module):
|
| 87 |
+
""" Vision Transformer with 3D Patch Embedding
|
| 88 |
+
"""
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
img_size: Union[int, Tuple[int, int, int]] = (128, 128, 64),
|
| 92 |
+
patch_size: Union[int, Tuple[int, int, int]] = (16, 16, 8),
|
| 93 |
+
in_chans: int = 1,
|
| 94 |
+
num_classes: int = 1000,
|
| 95 |
+
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
|
| 96 |
+
embed_dim: int = 768,
|
| 97 |
+
depth: int = 12,
|
| 98 |
+
num_heads: int = 12,
|
| 99 |
+
attn_mode: str = 'mha',
|
| 100 |
+
q_proj_dim: Optional[int] = None,
|
| 101 |
+
kv_proj_dim: Optional[int] = None,
|
| 102 |
+
mlp_ratio: float = 4.,
|
| 103 |
+
qkv_bias: bool = True,
|
| 104 |
+
qk_norm: bool = False,
|
| 105 |
+
proj_bias: bool = True,
|
| 106 |
+
init_values: Optional[float] = None,
|
| 107 |
+
class_token: bool = True,
|
| 108 |
+
pos_embed: str = 'learn',
|
| 109 |
+
no_embed_class: bool = False,
|
| 110 |
+
rope_kwargs: Optional[dict] = None,
|
| 111 |
+
reg_tokens: int = 0,
|
| 112 |
+
pre_norm: bool = False,
|
| 113 |
+
final_norm: bool = True,
|
| 114 |
+
fc_norm: Optional[bool] = None,
|
| 115 |
+
dynamic_img_size: bool = False,
|
| 116 |
+
dynamic_img_pad: bool = False,
|
| 117 |
+
drop_rate: float = 0.,
|
| 118 |
+
pos_drop_rate: float = 0.,
|
| 119 |
+
patch_drop_rate: float = 0.,
|
| 120 |
+
proj_drop_rate: float = 0.,
|
| 121 |
+
attn_drop_rate: float = 0.,
|
| 122 |
+
drop_path_rate: float = 0.,
|
| 123 |
+
embed_layer: Callable = PatchEmbed,
|
| 124 |
+
embed_norm_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
|
| 125 |
+
norm_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
|
| 126 |
+
act_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
|
| 127 |
+
block_fn: Type[nn.Module] = Block,
|
| 128 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
| 129 |
+
) -> None:
|
| 130 |
+
"""
|
| 131 |
+
Args:
|
| 132 |
+
img_size: Input image size.
|
| 133 |
+
patch_size: Patch size.
|
| 134 |
+
in_chans: Number of image input channels.
|
| 135 |
+
num_classes: Number of classes for classification head.
|
| 136 |
+
global_pool: Type of global pooling for final sequence (default: 'token').
|
| 137 |
+
embed_dim: Transformer embedding dimension.
|
| 138 |
+
depth: Depth of transformer.
|
| 139 |
+
num_heads: Number of attention heads.
|
| 140 |
+
attn_mode: Attention mode ('mha', 'mqa', 'mla').
|
| 141 |
+
q_proj_dim: Query projection dimension for 'mla' mode.
|
| 142 |
+
kv_proj_dim: Key, value projection dimension for 'mla' mode.
|
| 143 |
+
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
| 144 |
+
qkv_bias: Enable bias for qkv projections if True.
|
| 145 |
+
init_values: Layer-scale init values (layer-scale enabled if not None).
|
| 146 |
+
class_token: Use class token.
|
| 147 |
+
pos_embed: Type of position embedding to use (default: 'learn').
|
| 148 |
+
no_embed_class: Don't include position embeddings for class (or reg) tokens for learnable pos_embed.
|
| 149 |
+
rope_kwargs: Additional arguments for rotary position embedding.
|
| 150 |
+
reg_tokens: Number of register tokens.
|
| 151 |
+
pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
|
| 152 |
+
final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
|
| 153 |
+
fc_norm: Move final norm after pool (instead of before); if None, enabled when global_pool is 'avg', 'avgmax', or 'max'.
|
| 154 |
+
drop_rate: Head dropout rate.
|
| 155 |
+
pos_drop_rate: Position embedding dropout rate.
|
| 156 |
+
attn_drop_rate: Attention dropout rate.
|
| 157 |
+
drop_path_rate: Stochastic depth rate.
|
| 158 |
+
patch_drop_rate: Patch (token) dropout rate.
|
| 159 |
+
proj_drop_rate: Dropout rate for the attention output projection and MLP layers.
|
| 160 |
+
embed_layer: Patch embedding layer.
|
| 161 |
+
embed_norm_layer: Normalization layer to use / override in patch embed module.
|
| 162 |
+
norm_layer: Normalization layer.
|
| 163 |
+
act_layer: MLP activation layer.
|
| 164 |
+
block_fn: Transformer block layer.
|
| 165 |
+
"""
|
| 166 |
+
super().__init__()
|
| 167 |
+
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
| 168 |
+
assert class_token or global_pool != 'token'
|
| 169 |
+
assert pos_embed in ('', 'none', 'learn', 'rope')
|
| 170 |
+
assert attn_mode in ('mha', 'mqa', 'mla')
|
| 171 |
+
rope_kwargs = {} if rope_kwargs is None else dict(rope_kwargs)
|
| 172 |
+
rope_kwargs.setdefault("dtype", torch.float32) # robust with mixed-precision
|
| 173 |
+
use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
|
| 174 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
| 175 |
+
embed_norm_layer = embed_norm_layer
|
| 176 |
+
act_layer = act_layer or nn.GELU
|
| 177 |
+
|
| 178 |
+
self.num_classes = num_classes
|
| 179 |
+
self.global_pool = global_pool
|
| 180 |
+
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
|
| 181 |
+
self.num_prefix_tokens = 1 if class_token else 0
|
| 182 |
+
self.num_prefix_tokens += reg_tokens
|
| 183 |
+
self.num_reg_tokens = reg_tokens
|
| 184 |
+
self.has_class_token = class_token
|
| 185 |
+
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
|
| 186 |
+
self.dynamic_img_size = dynamic_img_size
|
| 187 |
+
|
| 188 |
+
embed_args = {}
|
| 189 |
+
if self.dynamic_img_size:
|
| 190 |
+
# flatten deferred until after pos embed
|
| 191 |
+
embed_args.update(dict(strict_img_size=False, output_fmt="NHWDC"))
|
| 192 |
+
elif pos_embed == 'rope':
|
| 193 |
+
embed_args['output_fmt'] = "NHWDC"
|
| 194 |
+
if embed_norm_layer is not None:
|
| 195 |
+
embed_args['norm_layer'] = embed_norm_layer
|
| 196 |
+
self.patch_embed = embed_layer(
|
| 197 |
+
img_size=img_size,
|
| 198 |
+
patch_size=patch_size,
|
| 199 |
+
in_chans=in_chans,
|
| 200 |
+
embed_dim=embed_dim,
|
| 201 |
+
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
| 202 |
+
dynamic_img_pad=dynamic_img_pad,
|
| 203 |
+
**embed_args,
|
| 204 |
+
)
|
| 205 |
+
num_patches = self.patch_embed.num_patches
|
| 206 |
+
reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
|
| 207 |
+
|
| 208 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
| 209 |
+
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
|
| 210 |
+
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
| 211 |
+
self.pos_embed, self.rope, self.requires_per_sample_rope = None, None, False
|
| 212 |
+
if pos_embed == 'learn':
|
| 213 |
+
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
| 214 |
+
if pos_embed == 'rope':
|
| 215 |
+
self.rope = RotaryPositionEmbedding(
|
| 216 |
+
embed_dim=embed_dim,
|
| 217 |
+
num_heads=num_heads,
|
| 218 |
+
**rope_kwargs,
|
| 219 |
+
)
|
| 220 |
+
self.requires_per_sample_rope = any([
|
| 221 |
+
self.rope.shift_coords is not None,
|
| 222 |
+
self.rope.jitter_coords is not None,
|
| 223 |
+
self.rope.rescale_coords is not None,
|
| 224 |
+
])
|
| 225 |
+
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
| 226 |
+
if patch_drop_rate > 0:
|
| 227 |
+
self.patch_drop = PatchDropout(
|
| 228 |
+
patch_drop_rate,
|
| 229 |
+
num_prefix_tokens=self.num_prefix_tokens,
|
| 230 |
+
)
|
| 231 |
+
else:
|
| 232 |
+
self.patch_drop = nn.Identity()
|
| 233 |
+
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
| 234 |
+
|
| 235 |
+
dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)] # stochastic depth decay rule
|
| 236 |
+
self.blocks = nn.Sequential(*[
|
| 237 |
+
block_fn(
|
| 238 |
+
dim=embed_dim,
|
| 239 |
+
num_heads=num_heads,
|
| 240 |
+
attn_mode=attn_mode,
|
| 241 |
+
q_proj_dim=q_proj_dim,
|
| 242 |
+
kv_proj_dim=kv_proj_dim,
|
| 243 |
+
mlp_ratio=mlp_ratio,
|
| 244 |
+
qkv_bias=qkv_bias,
|
| 245 |
+
qk_norm=qk_norm,
|
| 246 |
+
proj_bias=proj_bias,
|
| 247 |
+
init_values=init_values,
|
| 248 |
+
proj_drop=proj_drop_rate,
|
| 249 |
+
attn_drop=attn_drop_rate,
|
| 250 |
+
drop_path=dpr[i],
|
| 251 |
+
norm_layer=norm_layer,
|
| 252 |
+
act_layer=act_layer,
|
| 253 |
+
mlp_layer=mlp_layer,
|
| 254 |
+
)
|
| 255 |
+
for i in range(depth)])
|
| 256 |
+
self.feature_info = [
|
| 257 |
+
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
|
| 258 |
+
self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
|
| 259 |
+
|
| 260 |
+
# Classifier Head
|
| 261 |
+
if global_pool == 'map':
|
| 262 |
+
self.attn_pool = AttentionPoolLatent(
|
| 263 |
+
self.embed_dim,
|
| 264 |
+
num_heads=num_heads,
|
| 265 |
+
mlp_ratio=mlp_ratio,
|
| 266 |
+
norm_layer=norm_layer,
|
| 267 |
+
act_layer=act_layer,
|
| 268 |
+
)
|
| 269 |
+
else:
|
| 270 |
+
self.attn_pool = None
|
| 271 |
+
self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
|
| 272 |
+
self.head_drop = nn.Dropout(drop_rate)
|
| 273 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 274 |
+
|
| 275 |
+
self.init_weights()
|
| 276 |
+
|
| 277 |
+
def init_weights(self) -> None:
|
| 278 |
+
if self.pos_embed is not None and not self.pos_embed.is_meta:
|
| 279 |
+
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
| 280 |
+
if self.cls_token is not None and not self.cls_token.is_meta:
|
| 281 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 282 |
+
if self.reg_token is not None and not self.reg_token.is_meta:
|
| 283 |
+
nn.init.normal_(self.reg_token, std=1e-6)
|
| 284 |
+
self.apply(self._init_weights)
|
| 285 |
+
|
| 286 |
+
def _init_weights(self, m: nn.Module) -> None:
|
| 287 |
+
# this fn left here for compat with downstream users
|
| 288 |
+
if isinstance(m, nn.Linear):
|
| 289 |
+
if not m.weight.is_meta:
|
| 290 |
+
nn.init.trunc_normal_(m.weight, std=.02)
|
| 291 |
+
if m.bias is not None and not m.bias.is_meta:
|
| 292 |
+
nn.init.zeros_(m.bias)
|
| 293 |
+
|
| 294 |
+
@torch.jit.ignore
|
| 295 |
+
def no_weight_decay(self) -> Set:
|
| 296 |
+
return {'pos_embed', 'cls_token', 'dist_token'}
|
| 297 |
+
|
| 298 |
+
@torch.jit.ignore
|
| 299 |
+
def get_classifier(self) -> nn.Module:
|
| 300 |
+
return self.head
|
| 301 |
+
|
| 302 |
+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
| 303 |
+
self.num_classes = num_classes
|
| 304 |
+
if global_pool is not None:
|
| 305 |
+
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
| 306 |
+
if global_pool == 'map' and self.attn_pool is None:
|
| 307 |
+
assert False, "Cannot currently add attention pooling in reset_classifier()."
|
| 308 |
+
elif global_pool != 'map' and self.attn_pool is not None:
|
| 309 |
+
self.attn_pool = None # remove attention pooling
|
| 310 |
+
self.global_pool = global_pool
|
| 311 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 312 |
+
|
| 313 |
+
def set_input_size(
|
| 314 |
+
self,
|
| 315 |
+
img_size: Optional[Tuple[int, int, int]] = None,
|
| 316 |
+
patch_size: Optional[Tuple[int, int, int]] = None,
|
| 317 |
+
):
|
| 318 |
+
"""Update the input image resolution and patch size.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
img_size: New input resolution, if None current resolution is used
|
| 322 |
+
patch_size: New patch size, if None existing patch size is used
|
| 323 |
+
"""
|
| 324 |
+
prev_grid_size = self.patch_embed.grid_size
|
| 325 |
+
self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
|
| 326 |
+
if self.pos_embed is not None:
|
| 327 |
+
num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
|
| 328 |
+
num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
|
| 329 |
+
if num_new_tokens != self.pos_embed.shape[1]:
|
| 330 |
+
self.pos_embed = nn.Parameter(resample_abs_pos_embed(
|
| 331 |
+
self.pos_embed,
|
| 332 |
+
new_size=self.patch_embed.grid_size,
|
| 333 |
+
old_size=prev_grid_size,
|
| 334 |
+
num_prefix_tokens=num_prefix_tokens,
|
| 335 |
+
verbose=True,
|
| 336 |
+
))
|
| 337 |
+
|
| 338 |
+
def _pos_embed(self, x: torch.Tensor):
|
| 339 |
+
if self.pos_embed is None and self.rope is None:
|
| 340 |
+
x = x.view(x.shape[0], -1, x.shape[-1])
|
| 341 |
+
if self.reg_token is not None:
|
| 342 |
+
x = torch.cat([self.reg_token.expand(x.shape[0], -1, -1), x], dim=1)
|
| 343 |
+
if self.cls_token is not None:
|
| 344 |
+
x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
|
| 345 |
+
return x, None
|
| 346 |
+
|
| 347 |
+
if self.dynamic_img_size or self.rope is not None:
|
| 348 |
+
B, H, W, D, C = x.shape
|
| 349 |
+
x = x.view(B, -1, C)
|
| 350 |
+
|
| 351 |
+
pos_embed, rope = None, None
|
| 352 |
+
if self.pos_embed is not None:
|
| 353 |
+
if self.dynamic_img_size:
|
| 354 |
+
prev_grid_size = self.patch_embed.grid_size
|
| 355 |
+
pos_embed = resample_abs_pos_embed(
|
| 356 |
+
self.pos_embed,
|
| 357 |
+
new_size=(H, W, D),
|
| 358 |
+
old_size=prev_grid_size,
|
| 359 |
+
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
| 360 |
+
)
|
| 361 |
+
else:
|
| 362 |
+
pos_embed = self.pos_embed
|
| 363 |
+
|
| 364 |
+
if self.rope is not None:
|
| 365 |
+
if self.requires_per_sample_rope:
|
| 366 |
+
rope = [self.rope(H=H, W=W, D=D) for _ in range(B)]
|
| 367 |
+
else:
|
| 368 |
+
rope = self.rope(H=H, W=W, D=D)
|
| 369 |
+
|
| 370 |
+
to_cat = []
|
| 371 |
+
if self.cls_token is not None:
|
| 372 |
+
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
| 373 |
+
if self.reg_token is not None:
|
| 374 |
+
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
| 375 |
+
|
| 376 |
+
if self.no_embed_class:
|
| 377 |
+
# deit-3, updated JAX (big vision)
|
| 378 |
+
# position embedding does not overlap with class token, add then concat
|
| 379 |
+
if pos_embed is not None:
|
| 380 |
+
x = x + pos_embed
|
| 381 |
+
if to_cat:
|
| 382 |
+
x = torch.cat(to_cat + [x], dim=1)
|
| 383 |
+
else:
|
| 384 |
+
# original timm, JAX, and deit vit impl
|
| 385 |
+
# pos_embed has entry for class token, concat then add
|
| 386 |
+
if to_cat:
|
| 387 |
+
x = torch.cat(to_cat + [x], dim=1)
|
| 388 |
+
if pos_embed is not None:
|
| 389 |
+
x = x + pos_embed
|
| 390 |
+
|
| 391 |
+
return self.pos_drop(x), rope
|
| 392 |
+
|
| 393 |
+
def forward_intermediates(
|
| 394 |
+
self,
|
| 395 |
+
x: torch.Tensor,
|
| 396 |
+
indices: Optional[Union[int, List[int]]] = None,
|
| 397 |
+
return_prefix_tokens: bool = False,
|
| 398 |
+
norm: bool = False,
|
| 399 |
+
stop_early: bool = False,
|
| 400 |
+
output_fmt: str = 'NCHWD',
|
| 401 |
+
intermediates_only: bool = False,
|
| 402 |
+
output_dict: bool = False,
|
| 403 |
+
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
|
| 404 |
+
""" Forward features that returns intermediates.
|
| 405 |
+
|
| 406 |
+
Args:
|
| 407 |
+
x: Input image tensor
|
| 408 |
+
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
| 409 |
+
return_prefix_tokens: Return both prefix and spatial intermediate tokens
|
| 410 |
+
norm: Apply norm layer to all intermediates
|
| 411 |
+
stop_early: Stop iterating over blocks when last desired intermediate hit
|
| 412 |
+
output_fmt: Shape of intermediate feature outputs
|
| 413 |
+
intermediates_only: Only return intermediate features
|
| 414 |
+
output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
|
| 415 |
+
Returns:
|
| 416 |
+
A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
|
| 417 |
+
'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
|
| 418 |
+
|
| 419 |
+
"""
|
| 420 |
+
assert output_fmt in ('NCHWD', 'NLC'), 'Output format must be one of NCHWD or NLC.'
|
| 421 |
+
reshape = output_fmt == 'NCHWD'
|
| 422 |
+
intermediates = []
|
| 423 |
+
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
| 424 |
+
|
| 425 |
+
# forward pass
|
| 426 |
+
B, _, height, width, depth = x.shape
|
| 427 |
+
x = self.patch_embed(x)
|
| 428 |
+
x, rope = self._pos_embed(x)
|
| 429 |
+
x = self.patch_drop(x)
|
| 430 |
+
x = self.norm_pre(x)
|
| 431 |
+
|
| 432 |
+
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
| 433 |
+
blocks = self.blocks
|
| 434 |
+
else:
|
| 435 |
+
blocks = self.blocks[:max_index + 1]
|
| 436 |
+
for i, blk in enumerate(blocks):
|
| 437 |
+
x = blk(x, rope=rope)
|
| 438 |
+
if i in take_indices:
|
| 439 |
+
# normalize intermediates with final norm layer if enabled
|
| 440 |
+
intermediates.append(self.norm(x) if norm else x)
|
| 441 |
+
|
| 442 |
+
# process intermediates
|
| 443 |
+
if self.num_prefix_tokens:
|
| 444 |
+
# split prefix (e.g. class, register) and spatial feature tokens
|
| 445 |
+
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
|
| 446 |
+
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
|
| 447 |
+
else:
|
| 448 |
+
prefix_tokens = None
|
| 449 |
+
|
| 450 |
+
if reshape:
|
| 451 |
+
# reshape to NCHWD output format
|
| 452 |
+
H, W, D = self.patch_embed.dynamic_feat_size((height, width, depth))
|
| 453 |
+
intermediates = [y.reshape(B, H, W, D, -1).permute(0, 4, 1, 2, 3).contiguous() for y in intermediates]
|
| 454 |
+
|
| 455 |
+
if output_dict:
|
| 456 |
+
result_dict = {}
|
| 457 |
+
# Intermediates are always included
|
| 458 |
+
result_dict['image_intermediates'] = intermediates
|
| 459 |
+
if prefix_tokens is not None and return_prefix_tokens:
|
| 460 |
+
result_dict['image_intermediates_prefix'] = prefix_tokens
|
| 461 |
+
|
| 462 |
+
# Only include features if not intermediates_only
|
| 463 |
+
if not intermediates_only:
|
| 464 |
+
x_final = self.norm(x)
|
| 465 |
+
result_dict['image_features'] = x_final
|
| 466 |
+
|
| 467 |
+
return result_dict
|
| 468 |
+
|
| 469 |
+
# For non-dictionary output, maintain the original behavior
|
| 470 |
+
if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
|
| 471 |
+
# return_prefix_tokens is not supported in torchscript due to poor type handling
|
| 472 |
+
intermediates = list(zip(intermediates, prefix_tokens))
|
| 473 |
+
|
| 474 |
+
if intermediates_only:
|
| 475 |
+
return intermediates
|
| 476 |
+
|
| 477 |
+
x = self.norm(x)
|
| 478 |
+
|
| 479 |
+
return x, intermediates
|
| 480 |
+
|
| 481 |
+
def prune_intermediate_layers(
|
| 482 |
+
self,
|
| 483 |
+
indices: Union[int, List[int]] = 1,
|
| 484 |
+
prune_norm: bool = False,
|
| 485 |
+
prune_head: bool = True,
|
| 486 |
+
):
|
| 487 |
+
"""Prune layers not required for specified intermediates.
|
| 488 |
+
|
| 489 |
+
Args:
|
| 490 |
+
indices: Indices of intermediate layers to keep.
|
| 491 |
+
prune_norm: Whether to prune normalization layer.
|
| 492 |
+
prune_head: Whether to prune the classifier head.
|
| 493 |
+
|
| 494 |
+
Returns:
|
| 495 |
+
List of indices that were kept.
|
| 496 |
+
"""
|
| 497 |
+
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
| 498 |
+
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
| 499 |
+
if prune_norm:
|
| 500 |
+
self.norm = nn.Identity()
|
| 501 |
+
if prune_head:
|
| 502 |
+
self.fc_norm = nn.Identity()
|
| 503 |
+
self.reset_classifier(0, '')
|
| 504 |
+
return take_indices
|
| 505 |
+
|
| 506 |
+
def get_intermediate_layers(
|
| 507 |
+
self,
|
| 508 |
+
x: torch.Tensor,
|
| 509 |
+
n: Union[int, List[int], Tuple[int]] = 1,
|
| 510 |
+
reshape: bool = False,
|
| 511 |
+
return_prefix_tokens: bool = False,
|
| 512 |
+
norm: bool = False,
|
| 513 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 514 |
+
"""Get intermediate layer outputs (DINO interface compatibility).
|
| 515 |
+
|
| 516 |
+
NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
|
| 517 |
+
|
| 518 |
+
Args:
|
| 519 |
+
x: Input tensor.
|
| 520 |
+
n: Number or indices of layers.
|
| 521 |
+
reshape: Reshape to NCHWD format.
|
| 522 |
+
return_prefix_tokens: Return prefix tokens.
|
| 523 |
+
norm: Apply normalization.
|
| 524 |
+
|
| 525 |
+
Returns:
|
| 526 |
+
List of intermediate features.
|
| 527 |
+
"""
|
| 528 |
+
return self.forward_intermediates(
|
| 529 |
+
x, n,
|
| 530 |
+
return_prefix_tokens=return_prefix_tokens,
|
| 531 |
+
norm=norm,
|
| 532 |
+
output_fmt='NCHWD' if reshape else 'NLC',
|
| 533 |
+
intermediates_only=True,
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 537 |
+
"""Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm)."""
|
| 538 |
+
x = self.patch_embed(x)
|
| 539 |
+
x, rope = self._pos_embed(x)
|
| 540 |
+
x = self.patch_drop(x)
|
| 541 |
+
x = self.norm_pre(x)
|
| 542 |
+
|
| 543 |
+
for blk in self.blocks:
|
| 544 |
+
x = blk(x, rope=rope)
|
| 545 |
+
x = self.norm(x)
|
| 546 |
+
return x
|
| 547 |
+
|
| 548 |
+
def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
|
| 549 |
+
"""Apply pooling to feature tokens.
|
| 550 |
+
|
| 551 |
+
Args:
|
| 552 |
+
x: Feature tensor.
|
| 553 |
+
pool_type: Pooling type override.
|
| 554 |
+
|
| 555 |
+
Returns:
|
| 556 |
+
Pooled features.
|
| 557 |
+
"""
|
| 558 |
+
if self.attn_pool is not None:
|
| 559 |
+
x = self.attn_pool(x)
|
| 560 |
+
return x
|
| 561 |
+
pool_type = self.global_pool if pool_type is None else pool_type
|
| 562 |
+
x = global_pool_nlc(
|
| 563 |
+
x,
|
| 564 |
+
pool_type=pool_type,
|
| 565 |
+
num_prefix_tokens=self.num_prefix_tokens,
|
| 566 |
+
)
|
| 567 |
+
return x
|
| 568 |
+
|
| 569 |
+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
| 570 |
+
"""Forward pass through classifier head.
|
| 571 |
+
|
| 572 |
+
Args:
|
| 573 |
+
x: Feature tensor.
|
| 574 |
+
pre_logits: Return features before final classifier.
|
| 575 |
+
|
| 576 |
+
Returns:
|
| 577 |
+
Output tensor.
|
| 578 |
+
"""
|
| 579 |
+
x = self.pool(x)
|
| 580 |
+
x = self.fc_norm(x)
|
| 581 |
+
x = self.head_drop(x)
|
| 582 |
+
return x if pre_logits else self.head(x)
|
| 583 |
+
|
| 584 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 585 |
+
x = self.forward_features(x)
|
| 586 |
+
x = self.forward_head(x)
|
| 587 |
+
return x
|
| 588 |
+
|
| 589 |
+
@classmethod
|
| 590 |
+
def from_pretrained(
|
| 591 |
+
cls,
|
| 592 |
+
checkpoint_path_or_url: Union[str, os.PathLike],
|
| 593 |
+
verbose: bool = True,
|
| 594 |
+
**kwargs
|
| 595 |
+
) -> 'VisionTransformer':
|
| 596 |
+
"""Load pretrained model weights from a local path or a URL."""
|
| 597 |
+
model = cls(**kwargs)
|
| 598 |
+
|
| 599 |
+
def _is_url(path: str) -> bool:
|
| 600 |
+
try:
|
| 601 |
+
parsed = urlparse(str(path))
|
| 602 |
+
return parsed.scheme in ('http', 'https')
|
| 603 |
+
except Exception:
|
| 604 |
+
return False
|
| 605 |
+
|
| 606 |
+
def _is_hf_url(path: str) -> bool:
|
| 607 |
+
try:
|
| 608 |
+
parsed = urlparse(str(path))
|
| 609 |
+
return 'huggingface.co' in parsed.netloc
|
| 610 |
+
except Exception:
|
| 611 |
+
return False
|
| 612 |
+
|
| 613 |
+
if _is_hf_url(checkpoint_path_or_url):
|
| 614 |
+
if verbose:
|
| 615 |
+
print(f"Downloading pretrained weights from Hugging Face URL: {checkpoint_path_or_url}")
|
| 616 |
+
# Extract repo_id and filename from the URL
|
| 617 |
+
parsed = urlparse(checkpoint_path_or_url)
|
| 618 |
+
parts = parsed.path.strip('/').split('/')
|
| 619 |
+
repo_id = '/'.join(parts[:2]) # e.g., 'cclaess/SPECTRE'
|
| 620 |
+
filename = parts[-1] # e.g., 'spectre_backbone_vit_large_patch16_128.pt'
|
| 621 |
+
|
| 622 |
+
local_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 623 |
+
state_dict = load_state_dict_from_file(local_path, map_location='cpu')
|
| 624 |
+
elif _is_url(checkpoint_path_or_url):
|
| 625 |
+
if verbose:
|
| 626 |
+
print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
|
| 627 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 628 |
+
checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
|
| 629 |
+
else:
|
| 630 |
+
local_path = os.fspath(checkpoint_path_or_url)
|
| 631 |
+
if not os.path.exists(local_path):
|
| 632 |
+
raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
|
| 633 |
+
if verbose:
|
| 634 |
+
print(f"Loading checkpoint from local path: {local_path}")
|
| 635 |
+
state_dict = torch.load(local_path, map_location='cpu', weights_only=False)
|
| 636 |
+
|
| 637 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
| 638 |
+
if verbose:
|
| 639 |
+
print(f"Loaded pretrained weights with msg: {msg}")
|
| 640 |
+
return model
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def vit_tiny_patch16_128(
|
| 644 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 645 |
+
**kwargs
|
| 646 |
+
) -> VisionTransformer:
|
| 647 |
+
"""ViT-Tiny model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].
|
| 648 |
+
"""
|
| 649 |
+
kwargs = dict(
|
| 650 |
+
img_size=(128, 128, 64),
|
| 651 |
+
patch_size=(16, 16, 8),
|
| 652 |
+
embed_dim=192,
|
| 653 |
+
depth=12,
|
| 654 |
+
num_heads=2,
|
| 655 |
+
mlp_ratio=4,
|
| 656 |
+
qkv_bias=True,
|
| 657 |
+
norm_layer=nn.LayerNorm,
|
| 658 |
+
**kwargs,
|
| 659 |
+
)
|
| 660 |
+
if checkpoint_path_or_url is not None:
|
| 661 |
+
return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 662 |
+
return VisionTransformer(**kwargs)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def vit_small_patch16_128(
|
| 666 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 667 |
+
**kwargs
|
| 668 |
+
) -> VisionTransformer:
|
| 669 |
+
"""ViT-Small model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].
|
| 670 |
+
"""
|
| 671 |
+
kwargs = dict(
|
| 672 |
+
img_size=(128, 128, 64),
|
| 673 |
+
patch_size=(16, 16, 8),
|
| 674 |
+
embed_dim=384,
|
| 675 |
+
depth=12,
|
| 676 |
+
num_heads=4,
|
| 677 |
+
mlp_ratio=4,
|
| 678 |
+
qkv_bias=True,
|
| 679 |
+
norm_layer=nn.LayerNorm,
|
| 680 |
+
**kwargs,
|
| 681 |
+
)
|
| 682 |
+
if checkpoint_path_or_url is not None:
|
| 683 |
+
return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 684 |
+
return VisionTransformer(**kwargs)
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def vit_base_patch16_128(
|
| 688 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 689 |
+
**kwargs
|
| 690 |
+
) -> VisionTransformer:
|
| 691 |
+
"""ViT-Base model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].
|
| 692 |
+
"""
|
| 693 |
+
kwargs = dict(
|
| 694 |
+
img_size=(128, 128, 64),
|
| 695 |
+
patch_size=(16, 16, 8),
|
| 696 |
+
embed_dim=768,
|
| 697 |
+
depth=12,
|
| 698 |
+
num_heads=8,
|
| 699 |
+
mlp_ratio=4,
|
| 700 |
+
qkv_bias=True,
|
| 701 |
+
norm_layer=nn.LayerNorm,
|
| 702 |
+
**kwargs,
|
| 703 |
+
)
|
| 704 |
+
if checkpoint_path_or_url is not None:
|
| 705 |
+
return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 706 |
+
return VisionTransformer(**kwargs)
|
| 707 |
+
|
| 708 |
+
def vit_base_patch16_256(
|
| 709 |
+
pretrained_weights: Optional[str] = None,
|
| 710 |
+
**kwargs
|
| 711 |
+
) -> VisionTransformer:
|
| 712 |
+
"""ViT-Base model with 3D patch embedding, patch size [16, 16, 8] and input size [256, 256, 128].
|
| 713 |
+
"""
|
| 714 |
+
kwargs = dict(
|
| 715 |
+
img_size=(256, 256, 128),
|
| 716 |
+
patch_size=(16, 16, 8),
|
| 717 |
+
embed_dim=768,
|
| 718 |
+
depth=12,
|
| 719 |
+
num_heads=8,
|
| 720 |
+
mlp_ratio=4,
|
| 721 |
+
qkv_bias=True,
|
| 722 |
+
norm_layer=nn.LayerNorm,
|
| 723 |
+
**kwargs,
|
| 724 |
+
)
|
| 725 |
+
if pretrained_weights is not None:
|
| 726 |
+
return VisionTransformer.from_pretrained(pretrained_weights, **kwargs)
|
| 727 |
+
return VisionTransformer(**kwargs)
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
def vit_base_patch32_128(
|
| 731 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 732 |
+
**kwargs
|
| 733 |
+
) -> VisionTransformer:
|
| 734 |
+
"""ViT-Base model with 3D patch embedding, patch size [32, 32, 16] and input size [128, 128, 64].
|
| 735 |
+
"""
|
| 736 |
+
kwargs = dict(
|
| 737 |
+
img_size=(128, 128, 64),
|
| 738 |
+
patch_size=(32, 32, 16),
|
| 739 |
+
embed_dim=768,
|
| 740 |
+
depth=12,
|
| 741 |
+
num_heads=8,
|
| 742 |
+
mlp_ratio=4,
|
| 743 |
+
qkv_bias=True,
|
| 744 |
+
norm_layer=nn.LayerNorm,
|
| 745 |
+
**kwargs,
|
| 746 |
+
)
|
| 747 |
+
if checkpoint_path_or_url is not None:
|
| 748 |
+
return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 749 |
+
return VisionTransformer(**kwargs)
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def vit_large_patch16_128(
|
| 753 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 754 |
+
**kwargs
|
| 755 |
+
) -> VisionTransformer:
|
| 756 |
+
"""ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].
|
| 757 |
+
"""
|
| 758 |
+
kwargs = dict(
|
| 759 |
+
img_size=(128, 128, 64),
|
| 760 |
+
patch_size=(16, 16, 8),
|
| 761 |
+
embed_dim=1080,
|
| 762 |
+
depth=24,
|
| 763 |
+
num_heads=12,
|
| 764 |
+
mlp_ratio=4,
|
| 765 |
+
qkv_bias=True,
|
| 766 |
+
norm_layer=nn.LayerNorm,
|
| 767 |
+
**kwargs,
|
| 768 |
+
)
|
| 769 |
+
if checkpoint_path_or_url is not None:
|
| 770 |
+
return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 771 |
+
return VisionTransformer(**kwargs)
|
| 772 |
+
|
| 773 |
+
def vit_large_patch16_256(
|
| 774 |
+
pretrained_weights: Optional[str] = None,
|
| 775 |
+
**kwargs
|
| 776 |
+
) -> VisionTransformer:
|
| 777 |
+
"""ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [256, 256, 128].
|
| 778 |
+
"""
|
| 779 |
+
kwargs = dict(
|
| 780 |
+
img_size=(256, 256, 128),
|
| 781 |
+
patch_size=(16, 16, 8),
|
| 782 |
+
embed_dim=1080,
|
| 783 |
+
depth=24,
|
| 784 |
+
num_heads=12,
|
| 785 |
+
mlp_ratio=4,
|
| 786 |
+
qkv_bias=True,
|
| 787 |
+
norm_layer=nn.LayerNorm,
|
| 788 |
+
**kwargs,
|
| 789 |
+
)
|
| 790 |
+
if pretrained_weights is not None:
|
| 791 |
+
return VisionTransformer.from_pretrained(pretrained_weights, **kwargs)
|
| 792 |
+
return VisionTransformer(**kwargs)
|
| 793 |
+
|
| 794 |
+
def vit_large_patch16_320(
|
| 795 |
+
pretrained_weights: Optional[str] = None,
|
| 796 |
+
**kwargs
|
| 797 |
+
) -> VisionTransformer:
|
| 798 |
+
"""ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [320, 320, 128].
|
| 799 |
+
"""
|
| 800 |
+
kwargs = dict(
|
| 801 |
+
img_size=(320, 320, 128),
|
| 802 |
+
patch_size=(16, 16, 8),
|
| 803 |
+
embed_dim=1080,
|
| 804 |
+
depth=24,
|
| 805 |
+
num_heads=12,
|
| 806 |
+
mlp_ratio=4,
|
| 807 |
+
qkv_bias=True,
|
| 808 |
+
norm_layer=nn.LayerNorm,
|
| 809 |
+
**kwargs,
|
| 810 |
+
)
|
| 811 |
+
if pretrained_weights is not None:
|
| 812 |
+
return VisionTransformer.from_pretrained(pretrained_weights, **kwargs)
|
| 813 |
+
return VisionTransformer(**kwargs)
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
def vit_large_patch32_128(
|
| 817 |
+
checkpoint_path_or_url: Optional[str] = None,
|
| 818 |
+
**kwargs
|
| 819 |
+
) -> VisionTransformer:
|
| 820 |
+
"""ViT-Large model with 3D patch embedding, patch size [32, 32, 16] and input size [128, 128, 64].
|
| 821 |
+
"""
|
| 822 |
+
kwargs = dict(
|
| 823 |
+
img_size=(128, 128, 64),
|
| 824 |
+
patch_size=(32, 32, 16),
|
| 825 |
+
embed_dim=1080,
|
| 826 |
+
depth=24,
|
| 827 |
+
num_heads=12,
|
| 828 |
+
mlp_ratio=4,
|
| 829 |
+
qkv_bias=True,
|
| 830 |
+
norm_layer=nn.LayerNorm,
|
| 831 |
+
**kwargs,
|
| 832 |
+
)
|
| 833 |
+
if checkpoint_path_or_url is not None:
|
| 834 |
+
return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 835 |
+
return VisionTransformer(**kwargs)
|
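The factory functions above each pair a fixed input size with a fixed patch size. As a rough sanity check on the resulting sequence lengths, here is a minimal sketch that is not part of the repository: `token_grid` and `CONFIGS` are hypothetical names, and the sketch assumes non-overlapping patches that divide the volume evenly. The size pairs are copied from the `img_size` and `patch_size` arguments above.

```python
import math
from typing import Dict, Tuple

Size3 = Tuple[int, int, int]


def token_grid(img_size: Size3, patch_size: Size3) -> Size3:
    """Patch grid (H, W, D) for a volume split into non-overlapping patches."""
    if any(i % p != 0 for i, p in zip(img_size, patch_size)):
        raise ValueError(f"{img_size} is not divisible by {patch_size}")
    h, w, d = (i // p for i, p in zip(img_size, patch_size))
    return (h, w, d)


# (img_size, patch_size) pairs taken from the factory functions above.
CONFIGS: Dict[str, Tuple[Size3, Size3]] = {
    "vit_large_patch16_128": ((128, 128, 64), (16, 16, 8)),
    "vit_large_patch16_256": ((256, 256, 128), (16, 16, 8)),
    "vit_large_patch16_320": ((320, 320, 128), (16, 16, 8)),
    "vit_large_patch32_128": ((128, 128, 64), (32, 32, 16)),
}

for name, (img_size, patch_size) in CONFIGS.items():
    grid = token_grid(img_size, patch_size)
    # Patch tokens only; class and register tokens would be added on top.
    print(f"{name}: grid={grid}, patch tokens={math.prod(grid)}")
```

Under these assumptions the 128-input, patch-16 variants produce an 8 x 8 x 8 grid (512 patch tokens), while the 320-input variant produces 20 x 20 x 16 (6400 patch tokens), which is worth keeping in mind when budgeting attention memory.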
spectre/models/vision_transformer_features.py
ADDED
|
@@ -0,0 +1,455 @@
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
from functools import partial
|
| 4 |
+
from urllib.parse import urlparse
|
| 5 |
+
from typing import Union, Callable, Literal, Optional, Type, Set, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from timm.models.vision_transformer import Mlp
|
| 10 |
+
from timm.layers import PatchDropout, AttentionPoolLatent
|
| 11 |
+
from huggingface_hub import hf_hub_download, load_state_dict_from_file
|
| 12 |
+
|
| 13 |
+
from spectre.utils import global_pool_nlc, to_3tuple, resample_abs_pos_embed
|
| 14 |
+
from spectre.models.vision_transformer import Block
|
| 15 |
+
from spectre.models.layers import RotaryPositionEmbedding
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class FeatureVisionTransformer(nn.Module):
|
| 19 |
+
""" Vision Transformer that accepts flattened patches as input.
|
| 20 |
+
"""
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
grid_size: Optional[Union[int, Tuple[int, int, int]]] = None,
|
| 24 |
+
patch_dim: int = 768,
|
| 25 |
+
num_classes: int = 1000,
|
| 26 |
+
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
|
| 27 |
+
embed_dim: int = 768,
|
| 28 |
+
depth: int = 12,
|
| 29 |
+
num_heads: int = 12,
|
| 30 |
+
attn_mode: str = 'mha',
|
| 31 |
+
q_proj_dim: Optional[int] = None,
|
| 32 |
+
kv_proj_dim: Optional[int] = None,
|
| 33 |
+
mlp_ratio: float = 4.,
|
| 34 |
+
qkv_bias: bool = True,
|
| 35 |
+
qk_norm: bool = False,
|
| 36 |
+
proj_bias: bool = True,
|
| 37 |
+
init_values: Optional[float] = None,
|
| 38 |
+
class_token: bool = True,
|
| 39 |
+
pos_embed: str = 'learn',
|
| 40 |
+
no_embed_class: bool = False,
|
| 41 |
+
rope_kwargs: Optional[dict] = None,
|
| 42 |
+
reg_tokens: int = 0,
|
| 43 |
+
pre_norm: bool = False,
|
| 44 |
+
final_norm: bool = True,
|
| 45 |
+
fc_norm: Optional[bool] = None,
|
| 46 |
+
dynamic_grid_size: bool = False,
|
| 47 |
+
drop_rate: float = 0.,
|
| 48 |
+
pos_drop_rate: float = 0.,
|
| 49 |
+
patch_drop_rate: float = 0.,
|
| 50 |
+
proj_drop_rate: float = 0.,
|
| 51 |
+
attn_drop_rate: float = 0.,
|
| 52 |
+
drop_path_rate: float = 0.,
|
| 53 |
+
norm_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
|
| 54 |
+
act_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
|
| 55 |
+
block_fn: Type[nn.Module] = Block,
|
| 56 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
| 57 |
+
) -> None:
|
| 58 |
+
"""
|
| 59 |
+
Args:
|
| 60 |
+
grid_size: Patch grid size (H, W, D) of the input; required when pos_embed == 'learn'.
|
| 61 |
+
patch_dim: Dimension of each flattened input patch.
|
| 62 |
+
num_classes: Number of classes for classification head.
|
| 63 |
+
global_pool: Type of global pooling for final sequence (default: 'token').
|
| 64 |
+
embed_dim: Transformer embedding dimension.
|
| 65 |
+
depth: Depth of transformer.
|
| 66 |
+
num_heads: Number of attention heads.
|
| 67 |
+
attn_mode: Attention mode ('mha', 'mqa', 'mla').
|
| 68 |
+
q_proj_dim: Query projection dimension for 'mla' mode.
|
| 69 |
+
kv_proj_dim: Key, value projection dimension for 'mla' mode.
|
| 70 |
+
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
| 71 |
+
qkv_bias: Enable bias for qkv projections if True.
|
| 72 |
+
init_values: Layer-scale init values (layer-scale enabled if not None).
|
| 73 |
+
class_token: Use class token.
|
| 74 |
+
no_embed_class: Don't include position embeddings for class (or reg) tokens.
|
| 75 |
+
reg_tokens: Number of register tokens.
|
| 76 |
+
pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
|
| 77 |
+
final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
|
| 78 |
+
fc_norm: Move final norm after pool (instead of before); if None, enabled when global_pool is 'avg', 'avgmax', or 'max'.
|
| 79 |
+
drop_rate: Head dropout rate.
|
| 80 |
+
pos_drop_rate: Position embedding dropout rate.
|
| 81 |
+
attn_drop_rate: Attention dropout rate.
|
| 82 |
+
drop_path_rate: Stochastic depth rate.
|
| 83 |
+
pos_embed: Position embedding type ('', 'none', 'learn', 'rope').
|
| 84 |
+
dynamic_grid_size: Resample learned position embeddings to the grid size passed at forward time.
|
| 85 |
+
norm_layer: Normalization layer.
|
| 86 |
+
act_layer: MLP activation layer.
|
| 87 |
+
block_fn: Transformer block layer.
|
| 88 |
+
"""
|
| 89 |
+
super().__init__()
|
| 90 |
+
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
| 91 |
+
assert class_token or global_pool != 'token'
|
| 92 |
+
assert pos_embed in ('', 'none', 'learn', 'rope')
|
| 93 |
+
assert attn_mode in ('mha', 'mqa', 'mla')
|
| 94 |
+
assert grid_size is not None or pos_embed in ('', 'none', 'rope')
|
| 95 |
+
rope_kwargs = {} if rope_kwargs is None else dict(rope_kwargs)
|
| 96 |
+
rope_kwargs.setdefault("dtype", torch.float32) # robust with mixed-precision
|
| 97 |
+
use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
|
| 98 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
| 99 |
+
act_layer = act_layer or nn.GELU
|
| 100 |
+
|
| 101 |
+
self.grid_size = None if grid_size is None else to_3tuple(grid_size)
|
| 102 |
+
self.num_classes = num_classes
|
| 103 |
+
self.global_pool = global_pool
|
| 104 |
+
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
|
| 105 |
+
self.num_prefix_tokens = 1 if class_token else 0
|
| 106 |
+
self.num_prefix_tokens += reg_tokens
|
| 107 |
+
self.num_reg_tokens = reg_tokens
|
| 108 |
+
self.has_class_token = class_token
|
| 109 |
+
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
|
| 110 |
+
self.dynamic_grid_size = dynamic_grid_size
|
| 111 |
+
|
| 112 |
+
self.num_patches = None if grid_size is None else int(math.prod(grid_size))
|
| 113 |
+
self.patch_proj = nn.Linear(patch_dim, embed_dim, proj_bias)
|
| 114 |
+
|
| 115 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
| 116 |
+
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
|
| 117 |
+
self.pos_embed, self.rope, self.requires_per_sample_rope = None, None, False
|
| 118 |
+
if pos_embed == 'learn':
|
| 119 |
+
embed_len = self.num_patches if no_embed_class else self.num_patches + self.num_prefix_tokens
|
| 120 |
+
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
| 121 |
+
if pos_embed == 'rope':
|
| 122 |
+
self.rope = RotaryPositionEmbedding(
|
| 123 |
+
embed_dim=embed_dim,
|
| 124 |
+
num_heads=num_heads,
|
| 125 |
+
**rope_kwargs,
|
| 126 |
+
)
|
| 127 |
+
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
| 128 |
+
if patch_drop_rate > 0:
|
| 129 |
+
self.patch_drop = PatchDropout(
|
| 130 |
+
patch_drop_rate,
|
| 131 |
+
num_prefix_tokens=self.num_prefix_tokens,
|
| 132 |
+
)
|
| 133 |
+
else:
|
| 134 |
+
self.patch_drop = nn.Identity()
|
| 135 |
+
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
| 136 |
+
|
| 137 |
+
dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)] # stochastic depth decay rule
|
| 138 |
+
self.blocks = nn.Sequential(*[
|
| 139 |
+
block_fn(
|
| 140 |
+
dim=embed_dim,
|
| 141 |
+
num_heads=num_heads,
|
| 142 |
+
attn_mode=attn_mode,
|
| 143 |
+
q_proj_dim=q_proj_dim,
|
| 144 |
+
kv_proj_dim=kv_proj_dim,
|
| 145 |
+
mlp_ratio=mlp_ratio,
|
| 146 |
+
qkv_bias=qkv_bias,
|
| 147 |
+
qk_norm=qk_norm,
|
| 148 |
+
proj_bias=proj_bias,
|
| 149 |
+
init_values=init_values,
|
| 150 |
+
proj_drop=proj_drop_rate,
|
| 151 |
+
attn_drop=attn_drop_rate,
|
| 152 |
+
drop_path=dpr[i],
|
| 153 |
+
norm_layer=norm_layer,
|
| 154 |
+
act_layer=act_layer,
|
| 155 |
+
mlp_layer=mlp_layer,
|
| 156 |
+
)
|
| 157 |
+
for i in range(depth)])
|
| 158 |
+
self.feature_info = [
|
| 159 |
+
dict(module=f'blocks.{i}', num_chs=embed_dim) for i in range(depth)]
|
| 160 |
+
self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
|
| 161 |
+
|
| 162 |
+
# Classifier Head
|
| 163 |
+
if global_pool == 'map':
|
| 164 |
+
self.attn_pool = AttentionPoolLatent(
|
| 165 |
+
self.embed_dim,
|
| 166 |
+
num_heads=num_heads,
|
| 167 |
+
mlp_ratio=mlp_ratio,
|
| 168 |
+
norm_layer=norm_layer,
|
| 169 |
+
act_layer=act_layer,
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
self.attn_pool = None
|
| 173 |
+
self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
|
| 174 |
+
self.head_drop = nn.Dropout(drop_rate)
|
| 175 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 176 |
+
|
| 177 |
+
self.init_weights()
|
| 178 |
+
|
| 179 |
+
def init_weights(self) -> None:
|
| 180 |
+
if self.pos_embed is not None and not self.pos_embed.is_meta:
|
| 181 |
+
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
| 182 |
+
if self.cls_token is not None and not self.cls_token.is_meta:
|
| 183 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 184 |
+
if self.reg_token is not None and not self.reg_token.is_meta:
|
| 185 |
+
nn.init.normal_(self.reg_token, std=1e-6)
|
| 186 |
+
self.apply(self._init_weights)
|
| 187 |
+
|
| 188 |
+
def _init_weights(self, m: nn.Module) -> None:
|
| 189 |
+
# this fn left here for compat with downstream users
|
| 190 |
+
if isinstance(m, nn.Linear):
|
| 191 |
+
if not m.weight.is_meta:
|
| 192 |
+
nn.init.trunc_normal_(m.weight, std=.02)
|
| 193 |
+
if m.bias is not None and not m.bias.is_meta:
|
| 194 |
+
nn.init.zeros_(m.bias)
|
| 195 |
+
|
| 196 |
+
@torch.jit.ignore
|
| 197 |
+
def no_weight_decay(self) -> Set:
|
| 198 |
+
return {'pos_embed', 'cls_token', 'dist_token'}
|
| 199 |
+
|
| 200 |
+
@torch.jit.ignore
|
| 201 |
+
def get_classifier(self) -> nn.Module:
|
| 202 |
+
return self.head
|
| 203 |
+
|
| 204 |
+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
| 205 |
+
self.num_classes = num_classes
|
| 206 |
+
if global_pool is not None:
|
| 207 |
+
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
| 208 |
+
if global_pool == 'map' and self.attn_pool is None:
|
| 209 |
+
assert False, "Cannot currently add attention pooling in reset_classifier()."
|
| 210 |
+
elif global_pool != 'map' and self.attn_pool is not None:
|
| 211 |
+
self.attn_pool = None # remove attention pooling
|
| 212 |
+
self.global_pool = global_pool
|
| 213 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 214 |
+
|
| 215 |
+
def _pos_embed(
|
| 216 |
+
self,
|
| 217 |
+
x: torch.Tensor,
|
| 218 |
+
grid_size: Optional[Union[int, Tuple[int, int, int]]] = None,
|
| 219 |
+
):
|
| 220 |
+
if self.pos_embed is None and self.rope is None:
|
| 221 |
+
x = x.view(x.shape[0], -1, x.shape[-1])
|
| 222 |
+
if self.reg_token is not None:
|
| 223 |
+
x = torch.cat([self.reg_token.expand(x.shape[0], -1, -1), x], dim=1)
|
| 224 |
+
if self.cls_token is not None:
|
| 225 |
+
x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
|
| 226 |
+
return x, None
|
| 227 |
+
|
| 228 |
+
if self.dynamic_grid_size or self.rope is not None:
|
| 229 |
+
assert grid_size is not None, "grid_size must be provided when using dynamic_grid_size or RoPE."
|
| 230 |
+
|
| 231 |
+
pos_embed, rope = None, None
|
| 232 |
+
if self.pos_embed is not None:
|
| 233 |
+
if self.dynamic_grid_size:
|
| 234 |
+
H, W, D = to_3tuple(grid_size)
|
| 235 |
+
prev_grid_size = self.grid_size
|
| 236 |
+
pos_embed = resample_abs_pos_embed(
|
| 237 |
+
self.pos_embed,
|
| 238 |
+
new_size=(H, W, D),
|
| 239 |
+
old_size=prev_grid_size,
|
| 240 |
+
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
pos_embed = self.pos_embed
|
| 244 |
+
|
| 245 |
+
if self.rope is not None:
|
| 246 |
+
B = x.shape[0]
|
| 247 |
+
H, W, D = to_3tuple(grid_size)
|
| 248 |
+
if self.requires_per_sample_rope:
|
| 249 |
+
rope = [self.rope(H=H, W=W, D=D) for _ in range(B)]
|
| 250 |
+
else:
|
| 251 |
+
rope = self.rope(H=H, W=W, D=D)
|
| 252 |
+
|
| 253 |
+
to_cat = []
|
| 254 |
+
if self.cls_token is not None:
|
| 255 |
+
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
| 256 |
+
if self.reg_token is not None:
|
| 257 |
+
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
| 258 |
+
|
| 259 |
+
if self.no_embed_class:
|
| 260 |
+
# deit-3, updated JAX (big vision)
|
| 261 |
+
# position embedding does not overlap with class token, add then concat
|
| 262 |
+
if pos_embed is not None:
|
| 263 |
+
x = x + pos_embed
|
| 264 |
+
if to_cat:
|
| 265 |
+
x = torch.cat(to_cat + [x], dim=1)
|
| 266 |
+
else:
|
| 267 |
+
# original timm, JAX, and deit vit impl
|
| 268 |
+
# pos_embed has entry for class token, concat then add
|
| 269 |
+
if to_cat:
|
| 270 |
+
x = torch.cat(to_cat + [x], dim=1)
|
| 271 |
+
if pos_embed is not None:
|
| 272 |
+
x = x + pos_embed
|
| 273 |
+
|
| 274 |
+
return self.pos_drop(x), rope
|
| 275 |
+
|
| 276 |
+
def forward_features(
|
| 277 |
+
self,
|
| 278 |
+
x: torch.Tensor,
|
| 279 |
+
grid_size: Optional[Union[int, Tuple[int, int, int]]] = None,
|
| 280 |
+
) -> torch.Tensor:
|
| 281 |
+
assert x.ndim == 3, f"Expected input with 3 dimensions (B, N, C), got {x.ndim}."
|
| 282 |
+
|
| 283 |
+
x = self.patch_proj(x)
|
| 284 |
+
x, rope = self._pos_embed(x, grid_size)
|
| 285 |
+
x = self.patch_drop(x)
|
| 286 |
+
x = self.norm_pre(x)
|
| 287 |
+
|
| 288 |
+
for blk in self.blocks:
|
| 289 |
+
x = blk(x, rope=rope)
|
| 290 |
+
x = self.norm(x)
|
| 291 |
+
return x
|
| 292 |
+
|
| 293 |
+
def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
|
| 294 |
+
if self.attn_pool is not None:
|
| 295 |
+
x = self.attn_pool(x)
|
| 296 |
+
return x
|
| 297 |
+
pool_type = self.global_pool if pool_type is None else pool_type
|
| 298 |
+
x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
|
| 299 |
+
return x
|
| 300 |
+
|
| 301 |
+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
| 302 |
+
x = self.pool(x)
|
| 303 |
+
x = self.fc_norm(x)
|
| 304 |
+
x = self.head_drop(x)
|
| 305 |
+
return x if pre_logits else self.head(x)
|
| 306 |
+
|
| 307 |
+
def forward(
|
| 308 |
+
self,
|
| 309 |
+
x: torch.Tensor,
|
| 310 |
+
grid_size: Optional[Union[int, Tuple[int, int, int]]] = None,
|
| 311 |
+
) -> torch.Tensor:
|
| 312 |
+
x = self.forward_features(x, grid_size)
|
| 313 |
+
x = self.forward_head(x)
|
| 314 |
+
return x
|
| 315 |
+
|
| 316 |
+
@classmethod
|
| 317 |
+
def from_pretrained(
|
| 318 |
+
cls,
|
| 319 |
+
checkpoint_path_or_url: Union[str, os.PathLike],
|
| 320 |
+
verbose: bool = True,
|
+        **kwargs
+    ) -> 'FeatureVisionTransformer':
+        """Load pretrained model weights from a local path or a URL."""
+        model = cls(**kwargs)
+
+        def _is_url(path: str) -> bool:
+            try:
+                parsed = urlparse(str(path))
+                return parsed.scheme in ('http', 'https')
+            except Exception:
+                return False
+
+        def _is_hf_url(path: str) -> bool:
+            try:
+                parsed = urlparse(str(path))
+                return 'huggingface.co' in parsed.netloc
+            except Exception:
+                return False
+
+        if _is_hf_url(checkpoint_path_or_url):
+            if verbose:
+                print(f"Downloading pretrained weights from Hugging Face URL: {checkpoint_path_or_url}")
+            # Extract repo_id and filename from the URL
+            parsed = urlparse(checkpoint_path_or_url)
+            parts = parsed.path.strip('/').split('/')
+            repo_id = '/'.join(parts[:2])  # e.g., 'cclaess/SPECTRE'
+            filename = parts[-1]  # e.g., 'spectre_backbone_vit_large_patch16_128.pt'
+
+            local_path = hf_hub_download(repo_id=repo_id, filename=filename)
+            state_dict = load_state_dict_from_file(local_path, map_location='cpu')
+        elif _is_url(checkpoint_path_or_url):
+            if verbose:
+                print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
+            state_dict = torch.hub.load_state_dict_from_url(
+                checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
+        else:
+            local_path = os.fspath(checkpoint_path_or_url)
+            if not os.path.exists(local_path):
+                raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
+            if verbose:
+                print(f"Loading checkpoint from local path: {local_path}")
+            state_dict = torch.load(local_path, map_location='cpu', weights_only=False)
+
+        msg = model.load_state_dict(state_dict, strict=False)
+        if verbose:
+            print(f"Loaded pretrained weights with msg: {msg}")
+        return model
+
+
+def feat_vit_tiny(
+    patch_dim,
+    checkpoint_path_or_url: Optional[str] = None,
+    **kwargs,
+) -> FeatureVisionTransformer:
+    """Feature ViT-Tiny model.
+    """
+    kwargs = dict(
+        patch_dim=patch_dim,
+        embed_dim=192,
+        depth=2,
+        num_heads=2,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=nn.LayerNorm,
+        **kwargs,
+    )
+    if checkpoint_path_or_url is not None:
+        return FeatureVisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
+    return FeatureVisionTransformer(**kwargs)
+
+
+def feat_vit_small(
+    patch_dim,
+    checkpoint_path_or_url: Optional[str] = None,
+    **kwargs,
+) -> FeatureVisionTransformer:
+    """Feature ViT-Small model.
+    """
+    kwargs = dict(
+        patch_dim=patch_dim,
+        embed_dim=384,
+        depth=2,
+        num_heads=4,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=nn.LayerNorm,
+        **kwargs,
+    )
+    if checkpoint_path_or_url is not None:
+        return FeatureVisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
+    return FeatureVisionTransformer(**kwargs)
+
+
+def feat_vit_base(
+    patch_dim,
+    checkpoint_path_or_url: Optional[str] = None,
+    **kwargs,
+) -> FeatureVisionTransformer:
+    """Feature ViT-Base model.
+    """
+    kwargs = dict(
+        patch_dim=patch_dim,
+        embed_dim=768,
+        depth=2,
+        num_heads=8,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=nn.LayerNorm,
+        **kwargs,
+    )
+    if checkpoint_path_or_url is not None:
+        return FeatureVisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
+    return FeatureVisionTransformer(**kwargs)
+
+
+def feat_vit_large(
+    patch_dim,
+    checkpoint_path_or_url: Optional[str] = None,
+    **kwargs,
+) -> FeatureVisionTransformer:
+    """Feature ViT-Large model.
+    """
+    kwargs = dict(
+        patch_dim=patch_dim,
+        embed_dim=1080,
+        depth=4,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=nn.LayerNorm,
+        **kwargs,
+    )
+    if checkpoint_path_or_url is not None:
+        return FeatureVisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
+    return FeatureVisionTransformer(**kwargs)
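The Hugging Face branch of `from_pretrained` above keeps the first two path segments as `repo_id` and only the last segment as `filename`, so a file stored in a subfolder or on a non-default revision would not be addressed as intended. Below is a minimal sketch of a stricter split, assuming the common `<owner>/<repo>/resolve/<revision>/<path>` URL layout. `split_hf_url` is a hypothetical helper and not part of this commit.

```python
from urllib.parse import urlparse


def split_hf_url(url: str):
    """Split a Hub file URL into (repo_id, revision, path_in_repo).

    Assumes the layout <owner>/<repo>/(resolve|blob)/<revision>/<path...>.
    """
    parsed = urlparse(str(url))
    host = parsed.netloc.lower()
    if parsed.scheme not in ("http", "https") or not (
        host == "huggingface.co" or host.endswith(".huggingface.co")
    ):
        raise ValueError(f"Not a Hugging Face URL: {url}")

    parts = parsed.path.strip("/").split("/")
    if len(parts) < 5 or parts[2] not in ("resolve", "blob"):
        raise ValueError(f"Unexpected Hugging Face URL layout: {url}")

    repo_id = "/".join(parts[:2])        # e.g. 'cclaess/SPECTRE'
    revision = parts[3]                  # e.g. 'main'
    path_in_repo = "/".join(parts[4:])   # keeps any subfolder
    return repo_id, revision, path_in_repo


print(split_hf_url(
    "https://huggingface.co/cclaess/SPECTRE/resolve/main/"
    "spectre_backbone_vit_large_patch16_128.pt"
))
```

The three pieces could then be passed on as `hf_hub_download(repo_id=..., filename=..., revision=...)`. Revisions that themselves contain a slash (for example `refs/pr/1`) are not handled by this sketch.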
spectre/utils/__init__.py
ADDED
@@ -0,0 +1,117 @@
+from ._utils import (
+    fix_random_seeds,
+    to_ntuple,
+    to_1tuple,
+    to_2tuple,
+    to_3tuple,
+    to_4tuple,
+)
+from .checkpointing import (
+    save_state,
+    load_state,
+    extract_model_from_checkpoint_dinov2,
+    extract_model_from_checkpoint_siglip,
+)
+from .collate import (
+    extended_collate_dino,
+    extended_collate_siglip,
+    collate_add_filenames,
+)
+from .config import setup
+from .dataloader import get_dataloader
+from .distributed import (
+    is_enabled,
+    get_global_size,
+    get_global_rank,
+    get_local_size,
+    get_local_rank,
+    init_distributed,
+)
+from .lora import add_lora_adapters
+from .masking import random_block_mask
+from .modeling import (
+    deactivate_requires_grad_and_to_eval,
+    activate_requires_grad_and_to_train,
+    update_momentum,
+    update_drop_path_rate,
+    repeat_token,
+    expand_index_like,
+    get_at_index,
+    set_at_index,
+    mask_at_index,
+    mask_bool,
+    patchify,
+    random_token_mask,
+    resample_abs_pos_embed,
+    resample_abs_pos_embed_nhwdc,
+    resample_patch_embed,
+    feature_take_indices,
+    global_pool_nlc,
+    cat_keep_shapes,
+    uncat_with_shapes,
+    last_token_pool,
+    Format,
+    nchwd_to,
+    nhwdc_to,
+)
+from .param_groups import get_param_groups_with_decay
+from .scheduler import (
+    linear_warmup_schedule,
+    cosine_schedule,
+    cosine_warmup_schedule,
+    CosineWarmupScheduler,
+)
+
+__all__ = [
+    "fix_random_seeds",
+    "to_ntuple",
+    "to_1tuple",
+    "to_2tuple",
+    "to_3tuple",
+    "to_4tuple",
+    "save_state",
+    "load_state",
+    "extract_model_from_checkpoint_dinov2",
+    "extract_model_from_checkpoint_siglip",
+    "extended_collate_dino",
+    "extended_collate_siglip",
+    "collate_add_filenames",
+    "setup",
+    "get_dataloader",
+    "is_enabled",
+    "get_global_size",
+    "get_global_rank",
+    "get_local_size",
+    "get_local_rank",
+    "init_distributed",
+    "add_lora_adapters",
+    "random_block_mask",
+    "deactivate_requires_grad_and_to_eval",
+    "activate_requires_grad_and_to_train",
+    "update_momentum",
+    "update_drop_path_rate",
+    "repeat_token",
+    "expand_index_like",
+    "get_at_index",
+    "set_at_index",
+    "mask_at_index",
+    "mask_bool",
+    "patchify",
+    "random_token_mask",
+    "resample_abs_pos_embed",
+    "resample_abs_pos_embed_nhwdc",
+    "resample_patch_embed",
+    "feature_take_indices",
+    "global_pool_nlc",
+    "cat_keep_shapes",
+    "uncat_with_shapes",
+    "last_token_pool",
+    "Format",
+    "nchwd_to",
+    "nhwdc_to",
+    "get_param_groups_with_decay",
+    "linear_warmup_schedule",
+    "cosine_schedule",
+    "cosine_warmup_schedule",
+    "CosineWarmupScheduler",
+]
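The `__init__.py` above repeats every re-exported name twice, once in an import and once in `__all__`, and the two lists can drift apart as the package grows. A small sketch of a consistency check follows. `missing_exports` and the stand-in module `demo_utils` are hypothetical and only serve to keep the example runnable without installing the package.

```python
import types


def missing_exports(module) -> list:
    """Return the names listed in ``module.__all__`` that the module does not define."""
    return [name for name in getattr(module, "__all__", []) if not hasattr(module, name)]


# Stand-in module so the sketch runs on its own.
demo = types.ModuleType("demo_utils")
demo.to_2tuple = lambda x: (x, x)
demo.__all__ = ["to_2tuple", "to_3tuple"]  # "to_3tuple" is deliberately left undefined

print(missing_exports(demo))  # ['to_3tuple']
```

With the package and its dependencies installed, calling `missing_exports` on `spectre.utils` would be expected to return an empty list. That is an expectation based on the listing above, not a verified result.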
spectre/utils/_utils.py
ADDED
@@ -0,0 +1,49 @@
+import random
+from typing import Iterable
+from itertools import repeat
+
+import torch
+import numpy as np
+
+MONAI_IMPORT_ERROR = None
+try:
+    import monai
+except ImportError as e:
+    monai = None  # type: ignore
+    MONAI_IMPORT_ERROR = e
+
+def fix_random_seeds(seed: int = 31):
+    """
+    Fix random seeds.
+    """
+    if MONAI_IMPORT_ERROR is not None:
+        raise ImportError(
+            "MONAI is required to use fix_random_seeds but not installed. "
+            "Please install MONAI to use this function."
+        ) from MONAI_IMPORT_ERROR
+
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    monai.utils.set_determinism(seed=seed)
+
+
+def _ntuple(n: int):
+    """
+    Helper function to create n-tuple.
+    """
+    def parse(x):
+        if isinstance(x, Iterable) and not isinstance(x, str):
+            return tuple(x)
+        return tuple(repeat(x, n))
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
spectre/utils/checkpointing.py
ADDED
@@ -0,0 +1,238 @@
+import os
+import random
+import warnings
+from typing import Optional, Any
+
+import torch
+import numpy as np
+import torch.distributed as dist
+
+
+def _get_local_rng_state() -> dict:
+    """Return a picklable dict with local RNG states (cpu & cuda, numpy, random)."""
+    state = {
+        "torch": torch.get_rng_state().cpu(),
+        "numpy": np.random.get_state(),
+        "random": random.getstate(),
+    }
+
+    if torch.cuda.is_available():
+        # make sure CUDA states are stored on CPU so they are picklable
+        cuda_states = [s.cpu() for s in torch.cuda.get_rng_state_all()]
+        state["cuda"] = cuda_states
+    else:
+        state["cuda"] = None
+
+    return state
+
+
+def _set_local_rng_state(state: dict) -> None:
+    """Set local RNG states from the dict produced by _get_local_rng_state()."""
+    if state is None:
+        return
+
+    if "torch" in state and state["torch"] is not None:
+        torch.set_rng_state(state["torch"])
+    if "cuda" in state and state["cuda"] is not None and torch.cuda.is_available():
+        try:
+            # CUDA RNG states must be passed as CPU ByteTensors (one per device)
+            cuda_states = [s.cpu() for s in state["cuda"]]
+            torch.cuda.set_rng_state_all(cuda_states)
+        except Exception:
+            # fallback: try setting per-device RNG if set_rng_state_all fails
+            for i, s in enumerate(state["cuda"]):
+                try:
+                    torch.cuda.set_rng_state(s.cpu(), device=i)
+                except Exception:
+                    # ignore if device mismatch
+                    pass
+
+    if "numpy" in state and state["numpy"] is not None:
+        np.random.set_state(state["numpy"])
+    if "random" in state and state["random"] is not None:
+        random.setstate(state["random"])
+
+
+def save_state(ckpt_path: str, epoch: Optional[int] = None, **named_objects: Any) -> None:
+    """
+    Save a checkpoint that includes:
+    - epoch (optional)
+    - state_dicts for provided named_objects
+    - rng_states: list of per-rank RNG dictionaries (one entry per world rank)
+
+    If torch.distributed is initialized the RNG states from all ranks are gathered and
+    stored in checkpoint["rng_states"] (list indexed by rank). Only rank 0 writes the file.
+    In single-process mode the checkpoint contains a single-item rng_states list.
+    """
+    os.makedirs(os.path.dirname(os.path.abspath(ckpt_path)), exist_ok=True)
+
+    # prepare local RNG state
+    local_rng = _get_local_rng_state()
+
+    # distributed path: gather RNG states from all ranks
+    if dist.is_available() and dist.is_initialized():
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        all_states = [None] * world_size
+        # gather python objects (picklable)
+        dist.all_gather_object(all_states, local_rng)
+
+        # only rank 0 writes the checkpoint file
+        if rank == 0:
+            checkpoint = {}
+            if epoch is not None:
+                checkpoint["epoch"] = epoch
+            checkpoint["rng_states"] = all_states
+
+            # save provided objects' state_dicts (rank 0's local state_dicts)
+            for name, obj in named_objects.items():
+                checkpoint[name] = obj.state_dict()
+
+            torch.save(checkpoint, ckpt_path)
+
+        # ensure everyone waits until rank 0 finished writing
+        dist.barrier()
+
+    else:
+        # single-process fallback
+        checkpoint = {}
+        if epoch is not None:
+            checkpoint["epoch"] = epoch
+        checkpoint["rng_states"] = [local_rng]
+        for name, obj in named_objects.items():
+            checkpoint[name] = obj.state_dict()
+        torch.save(checkpoint, ckpt_path)
+
+
+def load_state(ckpt_path: str, **named_objects: Any) -> int:
+    """
+    Load checkpoint saved by save_state.
+
+    - Each process loads the same file and restores its own RNG state (checkpoint['rng_states'][rank]).
+    - Named objects that exist in the checkpoint will have their state_dict loaded.
+    - Returns epoch (int) if present, otherwise 0.
+    """
+    if not os.path.isfile(ckpt_path):
+        warnings.warn(f"Checkpoint file not found: {ckpt_path}")
+        return 0
+
+    # load on all ranks (shared FS assumed)
+    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+    epoch = checkpoint.get("epoch", 0)
+
+    # load state_dicts into provided objects
+    for name, obj in named_objects.items():
+        if name in checkpoint:
+            try:
+                obj.load_state_dict(checkpoint[name])
+            except Exception as e:
+                warnings.warn(f"Failed to load state_dict for '{name}': {e}")
+        else:
+            warnings.warn(f"No state_dict found for '{name}' in checkpoint.")
+
+    # restore this rank's RNG state
+    rng_states = checkpoint.get("rng_states", None)
+    if rng_states is not None:
+        if dist.is_available() and dist.is_initialized():
+            rank = dist.get_rank()
+            if rank < len(rng_states):
+                my_state = rng_states[rank]
+            else:
+                my_state = None
+        else:
+            # single-process file: first element
+            my_state = rng_states[0] if len(rng_states) > 0 else None
+
+        try:
+            _set_local_rng_state(my_state)
+        except Exception as e:
+            warnings.warn(f"Failed to restore RNG state: {e}")
+
+    else:
+        warnings.warn("No 'rng_states' found in checkpoint; RNGs not restored.")
+
+    return epoch
+
+
+def extract_model_from_checkpoint_dinov2(checkpoint_path: str):
+    # Load the checkpoint
+    checkpoint = torch.load(
+        checkpoint_path,
+        weights_only=False,
+        map_location="cpu"
+    )
+
+    # Get model state dict
+    model_state = checkpoint.get("model", checkpoint)
+
+    # Create output folder
+    output_dir = str(checkpoint_path).replace(".pt", "")
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Quick check: compare the parameters of head_teacher_ibot vs head_teacher_dino
+    teacher_dino_keys = [k for k in model_state.keys() if k.startswith("head_teacher_dino.")]
+    teacher_ibot_keys = [k for k in model_state.keys() if k.startswith("head_teacher_ibot.")]
+
+    ibot_separate = True
+    if teacher_dino_keys and teacher_ibot_keys:
+        if all(torch.equal(model_state[dino_key], model_state[ibot_key]) \
+                for dino_key, ibot_key in zip(teacher_dino_keys, teacher_ibot_keys)):
+            ibot_separate = False  # Same weights → no separate ibot head
+
+    # Define the components to save
+    components = {
+        "backbone_teacher.pt": "backbone_teacher.vit",
+        "backbone_student.pt": "backbone_student.vit",
+        "head_student_dino.pt": "head_student_dino",
+        "head_teacher_dino.pt": "head_teacher_dino"
+    }
+
+    # Add ibot heads only if separate
+    if ibot_separate:
+        components["head_student_ibot.pt"] = "head_student_ibot"
+        components["head_teacher_ibot.pt"] = "head_teacher_ibot"
+
+    # Extract and save each component
+    for filename, key in components.items():
+        sub_state_dict = {k.replace(f"{key}.", ""): v for k, v in model_state.items() if k.startswith(key)}
+        if not sub_state_dict:
+            print(f"[WARNING] No parameters found for {key}, skipping...")
+            continue
+        torch.save(sub_state_dict, os.path.join(output_dir, filename))
+
+    print(f"Components extracted to: {output_dir}")
+
+
+def extract_model_from_checkpoint_siglip(checkpoint_path: str):
+    # Load the checkpoint
+    checkpoint = torch.load(
+        checkpoint_path,
+        weights_only=False,
+        map_location="cpu",
+    )
+
+    # Get model state dict
+    model_state = checkpoint.get("model", checkpoint)
+
+    # Create output folder
+    output_dir = str(checkpoint_path).replace(".pt", "")
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Define the components to save
+    components = {
+        "backbone_image.pt": "backbone_image",
+        "backbone_text.pt": "backbone_text",
+        "feature_comb_image.pt": "feature_comb_image",
+        "projection_image.pt": "projection_image",
+        "projection_text.pt": "projection_text"
+    }
+
+    # Extract and save each component
+    for filename, key in components.items():
+        sub_state_dict = {k.replace(f"{key}.", ""): v for k, v in model_state.items() if k.startswith(key)}
+        if not sub_state_dict:
+            print(f"[WARNING] No parameters found for {key}, skipping...")
+            continue
+        torch.save(sub_state_dict, os.path.join(output_dir, filename))
+
+    print(f"Components extracted to: {output_dir}")
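The two extraction helpers above select keys with `k.startswith(key)` and strip the prefix with `k.replace(f"{key}.", "")`. That works for the component names listed, but it would also pick up a sibling whose name merely begins with the same characters, and `replace` removes the substring wherever it occurs in the key. The sketch below shows a stricter split on plain dictionaries. `split_by_prefix` and the key `backbone_image_aux.weight` are hypothetical and used only for illustration.

```python
def split_by_prefix(state: dict, prefix: str) -> dict:
    """Return the entries stored under ``prefix.`` with that leading prefix removed."""
    full = prefix + "."
    # Match on the dotted prefix and slice it off the front only.
    return {k[len(full):]: v for k, v in state.items() if k.startswith(full)}


# Integers stand in for tensors so the sketch has no dependencies.
state = {
    "backbone_image.blocks.0.weight": 1,
    "backbone_image_aux.weight": 2,   # shares the bare prefix, different component
    "projection_image.weight": 3,
}

print(split_by_prefix(state, "backbone_image"))    # {'blocks.0.weight': 1}
print(split_by_prefix(state, "projection_image"))  # {'weight': 3}
```

Matching on `prefix + "."` keeps components apart even when one name is a prefix of another, at the cost of skipping a tensor stored under the bare component name with no dot.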
spectre/utils/collate.py
ADDED
@@ -0,0 +1,120 @@
+from typing import List, Callable, Optional
+
+import torch
+
+MONAI_IMPORT_ERROR = None
+try:
+    from monai.data import list_data_collate
+except ImportError as e:
+    list_data_collate = lambda x: x  # type: ignore
+    MONAI_IMPORT_ERROR = e
+
+
+def extended_collate_dino(samples_list: List) -> dict:
+    """
+    Applies MONAI's list_data_collate and then concatenates the global and local views.
+
+    Args:
+        samples_list: List of samples containing 'image_global_views' and 'image_local_views'.
+
+    Returns:
+        A dictionary with the concatenated 'global_views' and 'local_views' tensors.
+    """
+    if MONAI_IMPORT_ERROR is not None:
+        raise ImportError(
+            "MONAI is required to use extended_collate_dino but not installed. "
+            "Please install MONAI to use this collate function."
+        ) from MONAI_IMPORT_ERROR
+
+    # Apply MONAI's list_data_collate
+    collated_data = list_data_collate(samples_list)
+
+    # Extract crops
+    global_views = torch.cat(collated_data["image_global_views"], dim=0)
+    local_views = torch.cat(collated_data["image_local_views"], dim=0)
+
+    return {
+        "global_views": global_views,
+        "local_views": local_views,
+    }
+
+
+def extended_collate_siglip(
+    samples_list: List,
+    tokenizer: Optional[Callable] = None,
+    tokenizer_padding: bool = True,
+    tokenizer_truncation: bool = True,
+    tokenizer_max_length: Optional[int] = 1024,
+    return_filenames: bool = False
+) -> dict:
+    """
+    Applies MONAI's list_data_collate and then extends it with tokenization logic.
+
+    Args:
+        samples_list: List of samples containing 'image' and 'report'.
+        tokenizer: Tokenizer object (must provide batch_encode_plus) to apply on the reports.
+
+    Returns:
+        A dictionary with collated images and tokenized text.
+    """
+    if MONAI_IMPORT_ERROR is not None:
+        raise ImportError(
+            "MONAI is required to use extended_collate_siglip but not installed. "
+            "Please install MONAI to use this collate function."
+        ) from MONAI_IMPORT_ERROR
+
+    collated_data = list_data_collate(samples_list)
+
+    if return_filenames:
+        if "image" in collated_data.keys():
+            if (
+                hasattr(samples_list[0]["image"].data, "meta")
+                and "filename_or_obj" in samples_list[0]["image"].data.meta
+            ):
+                collated_data["filename"] = [s["image"].data.meta["filename_or_obj"] for s in samples_list]
+
+    if tokenizer is not None and "report" in collated_data.keys():
+        tokenizer_output = tokenizer.batch_encode_plus(
+            collated_data["report"],
+            add_special_tokens=True,
+            padding=tokenizer_padding,
+            truncation=tokenizer_truncation,
+            max_length=tokenizer_max_length,
+        )
+
+        collated_data["input_ids"] = torch.tensor(tokenizer_output["input_ids"])
+        collated_data["attention_mask"] = torch.tensor(tokenizer_output["attention_mask"])
+
+    return collated_data
+
+
+def collate_add_filenames(samples_list: List) -> dict:
+    """
+    Applies MONAI's list_data_collate and adds filenames to the collated output.
+
+    Args:
+        samples_list: List of samples containing 'image' with metadata.
+    Returns:
+        A dictionary with collated images and filenames.
+    """
+    if MONAI_IMPORT_ERROR is not None:
+        raise ImportError(
+            "MONAI is required to use collate_add_filenames but not installed. "
+            "Please install MONAI to use this collate function."
+        ) from MONAI_IMPORT_ERROR
+
+    collated_data = list_data_collate(samples_list)
+
+    if "image" in collated_data.keys():
+        if (
+            hasattr(samples_list[0]["image"].data, "meta")
+            and "filename_or_obj" in samples_list[0]["image"].data.meta
+        ):
+            collated_data["filename"] = [s["image"].data.meta["filename_or_obj"] for s in samples_list]
+
+    return collated_data
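`extended_collate_dino` above calls `torch.cat(..., dim=0)` on a list-valued key after collation. Assuming the collate step turns B samples with V views each into a list of V batched entries (the usual behaviour for list-valued fields, but an assumption here), the concatenated batch ends up view-major: all samples for view 0 first, then all samples for view 1, and so on. The dependency-free sketch below uses strings in place of tensors to make the ordering visible. `collate_views` and `view_chunks` are hypothetical helpers.

```python
def collate_views(samples):
    """Mimic collate followed by concatenation for a list-valued key."""
    per_view = list(zip(*samples))  # V groups, each holding B items
    return [item for group in per_view for item in group]


def view_chunks(flat, n_views):
    """Split a view-major batch back into one chunk per view."""
    size = len(flat) // n_views
    return [flat[i * size:(i + 1) * size] for i in range(n_views)]


# Three samples (s0..s2), two views each (v0, v1).
samples = [["s0_v0", "s0_v1"], ["s1_v0", "s1_v1"], ["s2_v0", "s2_v1"]]

flat = collate_views(samples)
print(flat)                  # ['s0_v0', 's1_v0', 's2_v0', 's0_v1', 's1_v1', 's2_v1']
print(view_chunks(flat, 2))  # [['s0_v0', 's1_v0', 's2_v0'], ['s0_v1', 's1_v1', 's2_v1']]
```

Under that assumption, downstream code that splits the batch into `n_views` equal chunks recovers one chunk per view rather than one chunk per sample.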
spectre/utils/config.py
ADDED
@@ -0,0 +1,91 @@
+import os
+import math
+
+from spectre.utils import _utils, distributed
+
+OMEGACONF_IMPORT_ERROR = None
+try:
+    from omegaconf import OmegaConf
+except ImportError as e:
+    OmegaConf = None  # type: ignore
+    OMEGACONF_IMPORT_ERROR = e
+
+
+def apply_scaling_rules_to_cfg(cfg):
+    """
+    Apply learning rate scaling rules to the configuration object.
+    """
+    base_lr = cfg.optim.base_lr
+    cfg.optim.lr = base_lr
+
+    # Apply scaling rules
+    if cfg.optim.scaling_rule == "constant":
+        return cfg
+
+    try:
+        scaling_type, ref_batch_size = cfg.optim.scaling_rule.split("_wrt_")
+        ref_batch_size = float(ref_batch_size)
+    except ValueError:
+        raise NotImplementedError(f"Unknown scaling rule: {cfg.optim.scaling_rule}")
+
+    scale_factor = cfg.train.batch_size_per_gpu * distributed.get_global_size()
+    scale_factor /= ref_batch_size
+    scale_factor *= cfg.train.grad_accum_steps
+
+    if scaling_type == "sqrt":
+        cfg.optim.lr *= math.sqrt(scale_factor)
+    elif scaling_type == "linear":
+        cfg.optim.lr *= scale_factor
+    else:
+        raise NotImplementedError(f"Unsupported scaling type: {scaling_type}")
+
+    return cfg
+
+
+def write_config(cfg, output_dir, name="config.yaml"):
+    if OMEGACONF_IMPORT_ERROR is not None:
+        raise ImportError(
+            "OmegaConf is required to use write_config but not installed. "
+            "Please install OmegaConf to use this function."
+        ) from OMEGACONF_IMPORT_ERROR
+
+    saved_cfg_path = os.path.join(output_dir, name)
+    with open(saved_cfg_path, "w") as f:
+        OmegaConf.save(config=cfg, f=f)
+    return saved_cfg_path
+
+
+def get_cfg_from_args(args, default_config):
+    if OMEGACONF_IMPORT_ERROR is not None:
+        raise ImportError(
+            "OmegaConf is required to use get_cfg_from_args but not installed. "
+            "Please install OmegaConf to use this function."
+        ) from OMEGACONF_IMPORT_ERROR
+
+    args.output_dir = os.path.abspath(args.output_dir)
+    args.opts = [] if args.opts is None else args.opts
+    args.opts += [f"train.output_dir={args.output_dir}"]
+    default_cfg = OmegaConf.create(default_config)
+    cfg = OmegaConf.load(args.config_file)
+    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+    return cfg
+
+
+def random_seed(args):
+    seed = getattr(args, "seed", 0)
+    rank = distributed.get_global_rank()
+
+    _utils.fix_random_seeds(seed + rank)
+
+
+def setup(args, default_config):
+    """
+    Create configs and perform basic setups.
+    """
+    cfg = get_cfg_from_args(args, default_config)
+    os.makedirs(args.output_dir, exist_ok=True)
+    random_seed(args)
+    accelerator = distributed.init_distributed(cfg)
+    apply_scaling_rules_to_cfg(cfg)
+    write_config(cfg, args.output_dir)
+    return cfg, accelerator
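`apply_scaling_rules_to_cfg` above scales the base learning rate by the ratio of the effective batch size (`batch_size_per_gpu * world_size * grad_accum_steps`) to a reference batch size parsed from a rule string such as `sqrt_wrt_1024`. The sketch below restates that arithmetic as a standalone function so the effect of a rule can be checked by hand. `scaled_lr` is a hypothetical helper and the numbers are illustrative, not taken from the repository's configs.

```python
import math


def scaled_lr(base_lr, scaling_rule, batch_size_per_gpu, world_size, grad_accum_steps=1):
    """Mirror the scaling arithmetic from the diff above on plain numbers."""
    if scaling_rule == "constant":
        return base_lr
    try:
        scaling_type, ref_batch_size = scaling_rule.split("_wrt_")
        ref_batch_size = float(ref_batch_size)
    except ValueError:
        raise NotImplementedError(f"Unknown scaling rule: {scaling_rule}")

    # Effective batch size relative to the reference batch size.
    scale = batch_size_per_gpu * world_size * grad_accum_steps / ref_batch_size

    if scaling_type == "sqrt":
        return base_lr * math.sqrt(scale)
    if scaling_type == "linear":
        return base_lr * scale
    raise NotImplementedError(f"Unsupported scaling type: {scaling_type}")


# 32 per GPU * 8 GPUs * 4 accumulation steps = 1024, so the factor is 1.
print(scaled_lr(1e-3, "sqrt_wrt_1024", 32, 8, 4))
# 64 * 8 * 4 = 2048, so the linear rule doubles the learning rate.
print(scaled_lr(1e-3, "linear_wrt_1024", 64, 8, 4))
```

A rule string without exactly one `_wrt_` separator, or with a non-numeric reference, ends in `NotImplementedError`, matching the behaviour of the function in the diff.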
spectre/utils/dataloader.py
ADDED
@@ -0,0 +1,126 @@
+from __future__ import annotations
+import os
+from typing import Union, Callable, Optional, List
+
+import torch
+from torch.utils.data import ConcatDataset
+
+MONAI_IMPORT_ERROR = None
+try:
+    import monai.data as data
+except ImportError as e:
+    data = None  # type: ignore
+    MONAI_IMPORT_ERROR = e
+
+
+
+def get_dataloader(
+    datasets: Union[str, List[str]],
+    data_dir: str,
+    include_reports: bool = False,
+    include_labels: bool = False,
+    cache_dataset: bool = False,
+    cache_dir: Optional[str] = None,
+    use_gds: bool = False,
+    transform: Optional[Callable] = None,
+    fraction: float = 1.0,
+    batch_size: int = 64,
+    num_workers: int = 4,
+    pin_memory: bool = True,
+    shuffle: bool = True,
+    collate_fn: Optional[Callable] = None,
+    drop_last: bool = True,
+    persistent_workers: bool = True,
+    use_thread: bool = False,
+) -> "DataLoader":
+    """
+    Get dataloader for training.
+    """
+    if MONAI_IMPORT_ERROR is not None:
+        raise ImportError(
+            "MONAI is required to use get_dataloader but not installed. "
+            "Please install MONAI to use this function."
+        ) from MONAI_IMPORT_ERROR
+
+    if isinstance(datasets, str):
+        datasets = [datasets]
+
+    # Validate constraints
+    if include_reports:
+        assert set(datasets).issubset({"ct_rate", "merlin", "inspect"}), \
+            "When include_reports=True, only 'ct_rate', 'merlin', and 'inspect' are allowed."
+    if include_labels:
+        assert set(datasets).issubset({"abdomen_atlas", "abdomenct_1k"}), \
+            "When include_labels=True, only 'abdomen_atlas' and 'abdomenct_1k' are allowed."
+    if use_gds:
+        assert cache_dataset, "GDS requires cache_dataset=True."
+        assert torch.cuda.is_available(), "GDS requires CUDA to be available."
+
+    # Dataset configurations
+    DATASET_CONFIGS = {
+        "ct_rate": {"folder": "CT-RATE", "base_name": "CTRate",
+                    "extra": {"include_reports": include_reports}},
+        "inspect": {"folder": "INSPECT", "base_name": "Inspect",
+                    "extra": {"include_reports": include_reports}},
+        "merlin": {"folder": "MERLIN", "base_name": "Merlin",
+                   "extra": {"include_reports": include_reports}},
+        "nlst": {"folder": "NLST", "base_name": "Nlst"},
+        "amos": {"folder": "Amos", "base_name": "Amos"},
+        "abdomen_atlas": {"folder": "AbdomenAtlas1.0Mini", "base_name": "AbdomenAtlas",
+                          "extra": {"include_labels": include_labels}},
|
| 71 |
+
"panorama": {"folder": "PANORAMA", "base_name": "Panorama"},
|
| 72 |
+
"abdomenct_1k": {"folder": "AbdomenCT-1K", "base_name": "AbdomenCT1K",
|
| 73 |
+
"extra": {"include_labels": include_labels}},
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
datasets_list = []
|
| 77 |
+
for ds in datasets:
|
| 78 |
+
if ds.lower() not in DATASET_CONFIGS:
|
| 79 |
+
raise NotImplementedError(f"Dataset {ds} not implemented.")
|
| 80 |
+
|
| 81 |
+
cfg = DATASET_CONFIGS[ds.lower()]
|
| 82 |
+
folder = cfg["folder"]
|
| 83 |
+
extra_args = cfg.get("extra", {})
|
| 84 |
+
|
| 85 |
+
kwargs = {
|
| 86 |
+
"data_dir": os.path.join(data_dir, folder),
|
| 87 |
+
"transform": transform,
|
| 88 |
+
"fraction": fraction,
|
| 89 |
+
**extra_args,
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
base_name = cfg["base_name"]
|
| 93 |
+
class_suffix = "Dataset"
|
| 94 |
+
if cache_dataset:
|
| 95 |
+
class_suffix = "GDSDataset" if use_gds else "PersistentDataset"
|
| 96 |
+
|
| 97 |
+
class_name = f"{base_name}{class_suffix}"
|
| 98 |
+
DatasetClass = getattr(__import__("spectre.data", fromlist=[class_name]), class_name)
|
| 99 |
+
|
| 100 |
+
if cache_dataset:
|
| 101 |
+
kwargs["cache_dir"] = os.path.join(cache_dir, folder)
|
| 102 |
+
if use_gds:
|
| 103 |
+
kwargs["device"] = torch.cuda.current_device()
|
| 104 |
+
|
| 105 |
+
datasets_list.append(DatasetClass(**kwargs))
|
| 106 |
+
|
| 107 |
+
dataset = datasets_list[0] if len(datasets_list) == 1 else ConcatDataset(datasets_list)
|
| 108 |
+
|
| 109 |
+
loader_cls = getattr(data, "ThreadDataLoader" if use_thread else "DataLoader")
|
| 110 |
+
loader_kwargs = {
|
| 111 |
+
"dataset": dataset,
|
| 112 |
+
"batch_size": batch_size,
|
| 113 |
+
"num_workers": num_workers,
|
| 114 |
+
"shuffle": shuffle,
|
| 115 |
+
"drop_last": drop_last,
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
if not use_thread:
|
| 119 |
+
loader_kwargs.update({
|
| 120 |
+
"pin_memory": pin_memory,
|
| 121 |
+
"persistent_workers": persistent_workers
|
| 122 |
+
})
|
| 123 |
+
if collate_fn is not None:
|
| 124 |
+
loader_kwargs["collate_fn"] = collate_fn
|
| 125 |
+
|
| 126 |
+
return loader_cls(**loader_kwargs)
|
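The class lookup in `get_dataloader` builds a class name from a per-dataset base name plus a suffix selected by the caching flags. A minimal sketch of that naming rule on its own, in plain Python; `resolve_dataset_class_name` and the reduced `BASE_NAMES` table are hypothetical names used for illustration and are not part of the repository:

```python
# Hypothetical helper mirroring the naming rule used inside get_dataloader.
# Only three of the configured datasets are listed to keep the sketch short.
BASE_NAMES = {
    "ct_rate": "CTRate",
    "merlin": "Merlin",
    "abdomen_atlas": "AbdomenAtlas",
}


def resolve_dataset_class_name(
    name: str, cache_dataset: bool = False, use_gds: bool = False
) -> str:
    """Return the class name that would be looked up in spectre.data."""
    key = name.lower()
    if key not in BASE_NAMES:
        raise NotImplementedError(f"Dataset {name} not implemented.")
    if use_gds and not cache_dataset:
        raise ValueError("GDS requires cache_dataset=True.")

    # Plain dataset by default; cached variants swap the suffix.
    suffix = "Dataset"
    if cache_dataset:
        suffix = "GDSDataset" if use_gds else "PersistentDataset"
    return f"{BASE_NAMES[key]}{suffix}"


print(resolve_dataset_class_name("ct_rate"))
print(resolve_dataset_class_name("Merlin", cache_dataset=True))
print(resolve_dataset_class_name("abdomen_atlas", cache_dataset=True, use_gds=True))
```

Keeping the rule in a small pure function like this would make it testable without MONAI or the datasets being present; whether the repository does so is not shown in this diff.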
spectre/utils/distributed.py
ADDED
|
@@ -0,0 +1,92 @@
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
|
| 5 |
+
ACCELERATE_IMPORT_ERROR = None
|
| 6 |
+
try:
|
| 7 |
+
from accelerate import Accelerator, DataLoaderConfiguration
|
| 8 |
+
except ImportError as e:
|
| 9 |
+
Accelerator = None # type: ignore
|
| 10 |
+
DataLoaderConfiguration = None # type: ignore
|
| 11 |
+
ACCELERATE_IMPORT_ERROR = e
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def is_enabled() -> bool:
|
| 15 |
+
"""
|
| 16 |
+
Returns:
|
| 17 |
+
True if distributed training is enabled
|
| 18 |
+
"""
|
| 19 |
+
return dist.is_available() and dist.is_initialized()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_global_size() -> int:
|
| 23 |
+
"""
|
| 24 |
+
Returns:
|
| 25 |
+
Number of processes in the distributed group
|
| 26 |
+
"""
|
| 27 |
+
if not is_enabled():
|
| 28 |
+
return 1
|
| 29 |
+
return dist.get_world_size()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_global_rank() -> int:
|
| 33 |
+
"""
|
| 34 |
+
Returns:
|
| 35 |
+
The rank of the current process in the distributed group
|
| 36 |
+
"""
|
| 37 |
+
if not is_enabled():
|
| 38 |
+
return 0
|
| 39 |
+
return dist.get_rank()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_local_size() -> int:
|
| 43 |
+
"""
|
| 44 |
+
Returns:
|
| 45 |
+
Number of processes on the current machine
|
| 46 |
+
"""
|
| 47 |
+
if not is_enabled():
|
| 48 |
+
return 1
|
| 49 |
+
return int(os.environ.get("LOCAL_SIZE", 1))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_local_rank() -> int:
|
| 53 |
+
"""
|
| 54 |
+
Returns:
|
| 55 |
+
The rank of the current process on the current machine
|
| 56 |
+
"""
|
| 57 |
+
if not is_enabled():
|
| 58 |
+
return 0
|
| 59 |
+
return int(os.environ.get("LOCAL_RANK", 0))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def init_distributed(cfg):
|
| 63 |
+
"""
|
| 64 |
+
Initialize distributed training.
|
| 65 |
+
"""
|
| 66 |
+
if ACCELERATE_IMPORT_ERROR is not None:
|
| 67 |
+
raise ImportError(
|
| 68 |
+
"Accelerate is required to use init_distributed but not installed. "
|
| 69 |
+
"Please install Accelerate to use this function."
|
| 70 |
+
) from ACCELERATE_IMPORT_ERROR
|
| 71 |
+
|
| 72 |
+
# Initialize accelerator
|
| 73 |
+
dataloader_config = DataLoaderConfiguration(
|
| 74 |
+
non_blocking=cfg.train.pin_memory,
|
| 75 |
+
)
|
| 76 |
+
accelerator = Accelerator(
|
| 77 |
+
gradient_accumulation_steps=cfg.train.grad_accum_steps,
|
| 78 |
+
log_with="wandb" if cfg.train.log_wandb else None,
|
| 79 |
+
dataloader_config=dataloader_config,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Initialize wandb
|
| 83 |
+
if cfg.train.log_wandb:
|
| 84 |
+
accelerator.init_trackers(
|
| 85 |
+
project_name="spectre",
|
| 86 |
+
config={k: v for d in cfg.values() for k, v in d.items()},
|
| 87 |
+
init_kwargs={
|
| 88 |
+
"dir": os.path.join(cfg.train.output_dir, "logs"),
|
| 89 |
+
},
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
return accelerator
|
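`init_distributed` hands the tracker a flattened copy of the config, built with a nested dict comprehension. The sketch below applies the same comprehension to a plain nested dict; the section and key names are made up, and the real `cfg` is presumably a config object exposing `.values()`. One consequence worth knowing is that a key repeated across sections keeps only the value from the later section.

```python
# Sample nested config; section and key names are illustrative assumptions.
cfg = {
    "train": {"batch_size": 64, "lr": 1e-4},
    "optim": {"lr": 5e-4, "weight_decay": 0.05},
}

# Same comprehension shape as in init_distributed: later sections win on clashes.
flat = {k: v for section in cfg.values() for k, v in section.items()}

# Hypothetical alternative: prefix keys with the section name to keep both values.
flat_prefixed = {
    f"{name}.{k}": v for name, section in cfg.items() for k, v in section.items()
}

print(flat)
print(flat_prefixed)
```

If section keys are guaranteed to be unique, the unprefixed form is fine; the prefixed form is only a suggestion for configs where they are not.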
spectre/utils/lora.py
ADDED
|
@@ -0,0 +1,38 @@
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import loralib as lora
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def add_lora_adapters(
|
| 6 |
+
root_module: nn.Module,
|
| 7 |
+
r: int = 8,
|
| 8 |
+
lora_alpha: int = 32,
|
| 9 |
+
lora_dropout: float = 0.05,
|
| 10 |
+
target_keywords: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
|
| 11 |
+
) -> None:
|
| 12 |
+
"""
|
| 13 |
+
Recursively traverses the model and replaces every `nn.Linear`
|
| 14 |
+
whose name contains one of `target_keywords` with a LoRA-augmented
|
| 15 |
+
linear layer from loralib.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
for name, child in list(root_module.named_children()):
|
| 19 |
+
# If the child is itself a container, recurse first
|
| 20 |
+
add_lora_adapters(child, r, lora_alpha, lora_dropout, target_keywords)
|
| 21 |
+
|
| 22 |
+
# Replace target linear layers
|
| 23 |
+
if isinstance(child, nn.Linear) and any(k in name for k in target_keywords):
|
| 24 |
+
lora_layer = lora.Linear( # loralib wrapper
|
| 25 |
+
in_features=child.in_features,
|
| 26 |
+
out_features=child.out_features,
|
| 27 |
+
r=r,
|
| 28 |
+
lora_alpha=lora_alpha,
|
| 29 |
+
lora_dropout=lora_dropout,
|
| 30 |
+
bias=child.bias is not None,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# copy original weights so that behaviour is identical pre-training
|
| 34 |
+
lora_layer.weight.data = child.weight.data.clone()
|
| 35 |
+
if child.bias is not None:
|
| 36 |
+
lora_layer.bias.data = child.bias.data.clone()
|
| 37 |
+
|
| 38 |
+
setattr(root_module, name, lora_layer) # hot-swap!
|
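The comment in `add_lora_adapters` about behaviour being identical before training relies on the low-rank update contributing nothing at initialization. A minimal sketch in plain Python, assuming the standard LoRA formulation y = W x + (lora_alpha / r) · B (A x) with B starting at zero (which, to my understanding, is how loralib initializes it); the helper names and the tiny matrices are illustrative only:

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(weight, lora_a, lora_b, x, r, lora_alpha):
    """Compute W x + (lora_alpha / r) * B (A x)."""
    base = matvec(weight, x)
    update = matvec(lora_b, matvec(lora_a, x))
    scale = lora_alpha / r
    return [b + scale * u for b, u in zip(base, update)]


weight = [[1.0, 2.0], [3.0, 4.0]]  # copied pretrained weight, shape (out=2, in=2)
lora_a = [[0.5, 0.25]]             # shape (r=1, in=2); random in practice
lora_b = [[0.0], [0.0]]            # shape (out=2, r=1); zero before training
x = [1.0, 1.0]

y_base = matvec(weight, x)
y_lora = lora_forward(weight, lora_a, lora_b, x, r=1, lora_alpha=32)
print(y_base, y_lora)
```

Because B is all zeros here, the two outputs coincide; once B receives gradient updates the adapted layer starts to diverge from the frozen base.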
spectre/utils/masking.py
ADDED
|
@@ -0,0 +1,196 @@
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _random_block_mask(
|
| 8 |
+
size: Tuple[int, int, int],
|
| 9 |
+
num_masks: int,
|
| 10 |
+
min_num_masks_per_block: int = 4,
|
| 11 |
+
max_num_masks_per_block: Optional[int] = None,
|
| 12 |
+
max_attempts_per_block: int = 10,
|
| 13 |
+
generator: Optional[torch.Generator] = None,
|
| 14 |
+
device: Optional[Union[torch.device, str]] = None,
|
| 15 |
+
) -> torch.Tensor:
|
| 16 |
+
"""3D helper: generate a (H, W, D) boolean mask by placing cuboidal blocks.
|
| 17 |
+
|
| 18 |
+
- size: (H, W, D)
|
| 19 |
+
- num_masks: target total number of masked voxels for this image
|
| 20 |
+
- min_num_masks_per_block / max_num_masks_per_block: voxel-range per block
|
| 21 |
+
"""
|
| 22 |
+
H, W, D = size
|
| 23 |
+
total = H * W * D
|
| 24 |
+
num_masks = min(max(0, int(num_masks)), total)
|
| 25 |
+
|
| 26 |
+
if max_num_masks_per_block is None:
|
| 27 |
+
max_num_masks_per_block = max(min_num_masks_per_block, num_masks)
|
| 28 |
+
|
| 29 |
+
mask = torch.zeros((H, W, D), dtype=torch.bool, device=device)
|
| 30 |
+
masked_count = 0
|
| 31 |
+
global_attempts = 0
|
| 32 |
+
|
| 33 |
+
orders = [(0, 1, 2), (1, 2, 0), (2, 0, 1)]
|
| 34 |
+
|
| 35 |
+
# Try to place blocks until we have enough masked voxels or we exceed attempts
|
| 36 |
+
while masked_count < num_masks and global_attempts < max_attempts_per_block:
|
| 37 |
+
global_attempts += 1
|
| 38 |
+
|
| 39 |
+
# choose target voxels for this block
|
| 40 |
+
target_voxels = int(torch.randint(
|
| 41 |
+
min_num_masks_per_block, max_num_masks_per_block + 1, (1,), generator=generator
|
| 42 |
+
).item())
|
| 43 |
+
|
| 44 |
+
found = False
|
| 45 |
+
local_attempts = 0
|
| 46 |
+
while not found and local_attempts < max_attempts_per_block:
|
| 47 |
+
local_attempts += 1
|
| 48 |
+
|
| 49 |
+
# random pick order for dims to reduce bias
|
| 50 |
+
order_idx = int(torch.randint(0, 3, (1,), generator=generator).item())
|
| 51 |
+
order = orders[order_idx]
|
| 52 |
+
|
| 53 |
+
# pick first dimension
|
| 54 |
+
if order[0] == 0:
|
| 55 |
+
h = int(torch.randint(1, min(H, target_voxels) + 1, (1,), generator=generator).item())
|
| 56 |
+
elif order[0] == 1:
|
| 57 |
+
w = int(torch.randint(1, min(W, target_voxels) + 1, (1,), generator=generator).item())
|
| 58 |
+
else:
|
| 59 |
+
d = int(torch.randint(1, min(D, target_voxels) + 1, (1,), generator=generator).item())
|
| 60 |
+
|
| 61 |
+
# progressively choose remaining dims while ensuring feasibility
|
| 62 |
+
try:
|
| 63 |
+
if order[0] == 0:
|
| 64 |
+
# h chosen -> pick w then compute d_needed
|
| 65 |
+
max_w = max(1, min(W, target_voxels // h))
|
| 66 |
+
w = int(torch.randint(1, max_w + 1, (1,), generator=generator).item())
|
| 67 |
+
d_needed = math.ceil(target_voxels / (h * w))
|
| 68 |
+
if d_needed <= D:
|
| 69 |
+
d = max(1, d_needed)
|
| 70 |
+
found = True
|
| 71 |
+
elif order[0] == 1:
|
| 72 |
+
# w chosen -> pick d then compute h_needed
|
| 73 |
+
max_d = max(1, min(D, target_voxels // w))
|
| 74 |
+
d = int(torch.randint(1, max_d + 1, (1,), generator=generator).item())
|
| 75 |
+
h_needed = math.ceil(target_voxels / (d * w))
|
| 76 |
+
if h_needed <= H:
|
| 77 |
+
h = max(1, h_needed)
|
| 78 |
+
found = True
|
| 79 |
+
else:
|
| 80 |
+
# d chosen -> pick h then compute w_needed
|
| 81 |
+
max_h = max(1, min(H, target_voxels // d))
|
| 82 |
+
h = int(torch.randint(1, max_h + 1, (1,), generator=generator).item())
|
| 83 |
+
w_needed = math.ceil(target_voxels / (d * h))
|
| 84 |
+
if w_needed <= W:
|
| 85 |
+
w = max(1, w_needed)
|
| 86 |
+
found = True
|
| 87 |
+
|
| 88 |
+
except ValueError:
|
| 89 |
+
# in case of invalid ranges (defensive); just continue trying
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
# fallback alternative attempt: try simple factorization heuristics
|
| 93 |
+
if not found:
|
| 94 |
+
# attempt small-to-large factorization
|
| 95 |
+
for hh in range(1, min(H, target_voxels) + 1):
|
| 96 |
+
for ww in range(1, min(W, target_voxels // hh) + 1):
|
| 97 |
+
dd = math.ceil(target_voxels / (hh * ww))
|
| 98 |
+
if dd <= D:
|
| 99 |
+
h, w, d = hh, ww, dd
|
| 100 |
+
found = True
|
| 101 |
+
break
|
| 102 |
+
if found:
|
| 103 |
+
break
|
| 104 |
+
|
| 105 |
+
if not found:
|
| 106 |
+
# couldn't find a fitting block this global attempt; move on
|
| 107 |
+
continue
|
| 108 |
+
|
| 109 |
+
# clamp block dims to volume just in case and ensure at least 1
|
| 110 |
+
h = min(max(1, int(h)), H)
|
| 111 |
+
w = min(max(1, int(w)), W)
|
| 112 |
+
d = min(max(1, int(d)), D)
|
| 113 |
+
|
| 114 |
+
# choose random location so block fits
|
| 115 |
+
x0 = int(torch.randint(0, (H - h) + 1, (1,), generator=generator).item()) if H - h > 0 else 0
|
| 116 |
+
y0 = int(torch.randint(0, (W - w) + 1, (1,), generator=generator).item()) if W - w > 0 else 0
|
| 117 |
+
z0 = int(torch.randint(0, (D - d) + 1, (1,), generator=generator).item()) if D - d > 0 else 0
|
| 118 |
+
|
| 119 |
+
mask[x0 : x0 + h, y0 : y0 + w, z0 : z0 + d] = True
|
| 120 |
+
masked_count = int(mask.sum().item())
|
| 121 |
+
|
| 122 |
+
# If still short, fill remaining voxels at random positions
|
| 123 |
+
if masked_count < num_masks:
|
| 124 |
+
remaining = num_masks - masked_count
|
| 125 |
+
indices = torch.nonzero(~mask, as_tuple=False)
|
| 126 |
+
if indices.numel() > 0:
|
| 127 |
+
perm = torch.randperm(indices.shape[0], generator=generator, device=mask.device)
|
| 128 |
+
pick = indices[perm[:remaining]]
|
| 129 |
+
mask[pick[:, 0], pick[:, 1], pick[:, 2]] = True
|
| 130 |
+
|
| 131 |
+
return mask
|
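The fallback factorization in `_random_block_mask` can be read in isolation as a smallest-first search for block dimensions. The sketch below restates just that search in plain Python; `fallback_block_dims` is a hypothetical name, and returning `None` stands in for the "no fitting block" case:

```python
import math
from typing import Optional, Tuple


def fallback_block_dims(
    target_voxels: int, size: Tuple[int, int, int]
) -> Optional[Tuple[int, int, int]]:
    """Find the first (h, w, d) with h * w * d >= target_voxels that fits in size."""
    H, W, D = size
    for h in range(1, min(H, target_voxels) + 1):
        for w in range(1, min(W, target_voxels // h) + 1):
            # Depth needed to reach the target with this h x w footprint.
            d = math.ceil(target_voxels / (h * w))
            if d <= D:
                return h, w, d
    return None


print(fallback_block_dims(10, (4, 4, 4)))   # a block that fits
print(fallback_block_dims(100, (4, 4, 4)))  # target larger than the volume
```

Because of the ceiling, the returned block can cover slightly more voxels than requested (12 for a target of 10 in the first call), and the search order favours thin blocks along the first axis.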
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def random_block_mask(
|
| 135 |
+
size: Tuple[int, int, int, int],
|
| 136 |
+
batch_mask_ratio: float = 0.5,
|
| 137 |
+
min_image_mask_ratio: float = 0.1,
|
| 138 |
+
max_image_mask_ratio: float = 0.5,
|
| 139 |
+
min_num_masks_per_block: int = 4,
|
| 140 |
+
max_num_masks_per_block: Optional[int] = None,
|
| 141 |
+
max_attempts_per_block: int = 10,
|
| 142 |
+
generator: Optional[torch.Generator] = None,
|
| 143 |
+
device: Optional[Union[torch.device, str]] = None,
|
| 144 |
+
) -> torch.Tensor:
|
| 145 |
+
"""Create random block masks for 3D volumes only.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
size: (B, H, W, D)
|
| 149 |
+
batch_mask_ratio: fraction of images in the batch to apply masking to
|
| 150 |
+
min_image_mask_ratio / max_image_mask_ratio: per-image mask fraction range
|
| 151 |
+
min_num_masks_per_block / max_num_masks_per_block: voxels per block range
|
| 152 |
+
max_attempts_per_block: attempts to find a fitting block
|
| 153 |
+
generator: optional torch.Generator for reproducibility.
|
| 154 |
+
device: device for returned tensor
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
boolean tensor with shape (B, H, W, D)
|
| 158 |
+
"""
|
| 159 |
+
if len(size) != 4:
|
| 160 |
+
raise ValueError("size must be (B, H, W, D) for 3D masking.")
|
| 161 |
+
|
| 162 |
+
B, H, W, D = size
|
| 163 |
+
|
| 164 |
+
if max_image_mask_ratio < min_image_mask_ratio:
|
| 165 |
+
raise ValueError("max_image_mask_ratio must be >= min_image_mask_ratio.")
|
| 166 |
+
|
| 167 |
+
num_images_masked = int(B * batch_mask_ratio)
|
| 168 |
+
probs = torch.linspace(min_image_mask_ratio, max_image_mask_ratio, num_images_masked + 1).tolist()
|
| 169 |
+
|
| 170 |
+
image_masks = []
|
| 171 |
+
total_voxels = H * W * D
|
| 172 |
+
|
| 173 |
+
for prob_min, prob_max in zip(probs[:-1], probs[1:]):
|
| 174 |
+
# choose number of masked voxels for this image
|
| 175 |
+
u = float(prob_min + (prob_max - prob_min) * torch.rand(1, generator=generator).item())
|
| 176 |
+
num_mask = int(total_voxels * u)
|
| 177 |
+
image_masks.append(
|
| 178 |
+
_random_block_mask(
|
| 179 |
+
size=(H, W, D),
|
| 180 |
+
num_masks=num_mask,
|
| 181 |
+
min_num_masks_per_block=min_num_masks_per_block,
|
| 182 |
+
max_num_masks_per_block=max_num_masks_per_block,
|
| 183 |
+
max_attempts_per_block=max_attempts_per_block,
|
| 184 |
+
generator=generator,
|
| 185 |
+
device=device,
|
| 186 |
+
)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Add non-masked images (all False) to fill the batch
|
| 190 |
+
for _ in range(num_images_masked, B):
|
| 191 |
+
image_masks.append(torch.zeros((H, W, D), dtype=torch.bool, device=device))
|
| 192 |
+
|
| 193 |
+
perm = torch.randperm(B, generator=generator).tolist()
|
| 194 |
+
image_masks = [image_masks[i] for i in perm]
|
| 195 |
+
|
| 196 |
+
return torch.stack(image_masks)
|
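`random_block_mask` does not draw every per-image mask ratio from the full range; it splits the range into one bin per masked image and samples once inside each bin. A minimal sketch of that stratified sampling in plain Python, using `random.Random` in place of a `torch.Generator`; `stratified_mask_ratios` is a hypothetical helper name:

```python
import random
from typing import List


def stratified_mask_ratios(
    batch_size: int,
    batch_mask_ratio: float,
    min_ratio: float,
    max_ratio: float,
    rng: random.Random,
) -> List[float]:
    """Sample one mask ratio per masked image, each from its own sub-interval."""
    num_masked = int(batch_size * batch_mask_ratio)
    if num_masked == 0:
        return []
    # Evenly spaced bin edges from min_ratio to max_ratio (num_masked + 1 edges).
    step = (max_ratio - min_ratio) / num_masked
    edges = [min_ratio + i * step for i in range(num_masked + 1)]
    return [lo + (hi - lo) * rng.random() for lo, hi in zip(edges[:-1], edges[1:])]


ratios = stratified_mask_ratios(8, 0.5, 0.1, 0.5, random.Random(0))
print(ratios)
```

With 8 images and `batch_mask_ratio=0.5`, four images are masked with ratios spread over [0.1, 0.2), [0.2, 0.3), [0.3, 0.4) and [0.4, 0.5); the final shuffle in the original function then randomizes which batch positions receive them.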
spectre/utils/modeling.py
ADDED
|
@@ -0,0 +1,550 @@
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import List, Tuple, Optional, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def deactivate_requires_grad_and_to_eval(model: nn.Module):
|
| 14 |
+
"""Deactivates the requires_grad flag for all parameters of a model and sets it to eval mode.
|
| 15 |
+
|
| 16 |
+
This has the same effect as permanently executing the model within a `torch.no_grad()`
|
| 17 |
+
context. Use this method to disable gradient computation and therefore
|
| 18 |
+
training for a model.
|
| 19 |
+
|
| 20 |
+
Examples:
|
| 21 |
+
>>> backbone = resnet18()
|
| 22 |
+
>>> deactivate_requires_grad_and_to_eval(backbone)
|
| 23 |
+
"""
|
| 24 |
+
for param in model.parameters():
|
| 25 |
+
param.requires_grad = False
|
| 26 |
+
model.eval()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def activate_requires_grad_and_to_train(model: nn.Module):
|
| 30 |
+
"""Activates the requires_grad flag for all parameters of a model and sets it to train mode.
|
| 31 |
+
|
| 32 |
+
Use this method to activate gradients for a model (e.g. after deactivating
|
| 33 |
+
them using `deactivate_requires_grad_and_to_eval(...)`).
|
| 34 |
+
|
| 35 |
+
Examples:
|
| 36 |
+
>>> backbone = resnet18()
|
| 37 |
+
>>> activate_requires_grad_and_to_train(backbone)
|
| 38 |
+
"""
|
| 39 |
+
for param in model.parameters():
|
| 40 |
+
param.requires_grad = True
|
| 41 |
+
model.train()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@torch.no_grad()
|
| 45 |
+
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float):
|
| 46 |
+
"""Updates parameters of `model_ema` with Exponential Moving Average of `model`
|
| 47 |
+
|
| 48 |
+
Momentum encoders are a crucial component of models such as MoCo or BYOL.
|
| 49 |
+
|
| 50 |
+
Examples:
|
| 51 |
+
>>> backbone = resnet18()
|
| 52 |
+
>>> projection_head = MoCoProjectionHead()
|
| 53 |
+
>>> backbone_momentum = copy.deepcopy(backbone)
|
| 54 |
+
>>> projection_head_momentum = copy.deepcopy(projection_head)
|
| 55 |
+
>>>
|
| 56 |
+
>>> # update momentum
|
| 57 |
+
>>> update_momentum(backbone, backbone_momentum, m=0.999)
|
| 58 |
+
>>> update_momentum(projection_head, projection_head_momentum, m=0.999)
|
| 59 |
+
"""
|
| 60 |
+
for param_ema, param in zip(model_ema.parameters(), model.parameters()):
|
| 61 |
+
param_ema.data = param_ema.data * m + param.data * (1.0 - m)
|
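The update rule in `update_momentum` is a per-parameter exponential moving average, ema = m · ema + (1 − m) · value. A minimal sketch on plain Python floats instead of tensors; `ema_update` and the toy "teacher"/"student" values are illustrative names, not part of the repository:

```python
from typing import List


def ema_update(ema_values: List[float], values: List[float], m: float) -> List[float]:
    """Return the EMA of `values` folded into `ema_values` with momentum m."""
    return [e * m + v * (1.0 - m) for e, v in zip(ema_values, values)]


teacher = [0.0, 1.0]  # momentum (EMA) copy
student = [1.0, 1.0]  # parameters being trained; held fixed here for clarity

# After k steps toward a fixed target, the remaining gap shrinks by m ** k.
for _ in range(3):
    teacher = ema_update(teacher, student, m=0.9)

print(teacher)
```

With m close to 1 (for example the 0.999 in the docstring example) the momentum copy trails the trained model slowly, which is the intended smoothing effect.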
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def update_drop_path_rate(
|
| 65 |
+
model: "VisionTransformer",
|
| 66 |
+
drop_path_rate: float,
|
| 67 |
+
mode: str = "linear",
|
| 68 |
+
) -> None:
|
| 69 |
+
"""Updates the drop path rate in a VisionTransformer model.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
model:
|
| 73 |
+
VisionTransformer model.
|
| 74 |
+
drop_path_rate:
|
| 75 |
+
Maximum drop path rate.
|
| 76 |
+
mode:
|
| 77 |
+
Drop path rate update mode. Can be "linear" or "uniform". Linear increases
|
| 78 |
+
the drop path rate from 0 to drop_path_rate over the depth of the model.
|
| 79 |
+
Uniform sets the drop path rate to drop_path_rate for all blocks.
|
| 80 |
+
Raises:
|
| 81 |
+
ValueError: If an unknown mode is provided.
|
| 82 |
+
"""
|
| 83 |
+
from timm.layers import DropPath
|
| 84 |
+
|
| 85 |
+
total_depth = len(model.blocks)
|
| 86 |
+
|
| 87 |
+
# Determine drop path rates based on the specified mode
|
| 88 |
+
if mode == "linear":
|
| 89 |
+
drop_probabilities = np.linspace(0, drop_path_rate, total_depth)
|
| 90 |
+
elif mode == "uniform":
|
| 91 |
+
drop_probabilities = [drop_path_rate for _ in range(total_depth)]
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(
|
| 94 |
+
f"Unknown mode: '{mode}', supported modes are 'linear' and 'uniform'."
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Update the drop path rate for each block in the model
|
| 98 |
+
for block, drop_prob in zip(model.blocks, drop_probabilities):
|
| 99 |
+
if drop_prob > 0.0:
|
| 100 |
+
block.drop_path1 = DropPath(drop_prob=drop_prob)
|
| 101 |
+
block.drop_path2 = DropPath(drop_prob=drop_prob)
|
| 102 |
+
else:
|
| 103 |
+
block.drop_path1 = nn.Identity()
|
| 104 |
+
block.drop_path2 = nn.Identity()
|
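The two modes of `update_drop_path_rate` differ only in how the per-block probabilities are laid out over depth. A small sketch of the schedule itself in plain Python (the original uses `np.linspace` for the linear case); `drop_path_schedule` is a hypothetical helper name:

```python
from typing import List


def drop_path_schedule(depth: int, drop_path_rate: float, mode: str = "linear") -> List[float]:
    """Per-block drop path probabilities for a model with `depth` blocks."""
    if mode == "linear":
        if depth == 1:
            return [0.0]  # a single-point linspace returns its start value
        # Evenly spaced from 0 at the first block to drop_path_rate at the last.
        return [drop_path_rate * i / (depth - 1) for i in range(depth)]
    if mode == "uniform":
        return [drop_path_rate] * depth
    raise ValueError(f"Unknown mode: '{mode}', supported modes are 'linear' and 'uniform'.")


print(drop_path_schedule(5, 0.2))
print(drop_path_schedule(3, 0.1, mode="uniform"))
```

In linear mode the first block always gets probability 0, which is why the original code replaces its drop path layers with `nn.Identity()`.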
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
|
| 108 |
+
"""Repeats a token size times.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
token: Token tensor with shape (1, 1, dim).
|
| 112 |
+
size: (batch_size, sequence_length) tuple.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Tensor with shape (batch_size, sequence_length, dim) containing copies
|
| 116 |
+
of the input token.
|
| 117 |
+
"""
|
| 118 |
+
batch_size, sequence_length = size
|
| 119 |
+
return token.repeat(batch_size, sequence_length, 1)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
"""Expands the index along the last dimension of the input tokens.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
index:
|
| 127 |
+
Index tensor with shape (batch_size, idx_length) where each entry is
|
| 128 |
+
an index in [0, sequence_length).
|
| 129 |
+
tokens:
|
| 130 |
+
Tokens tensor with shape (batch_size, sequence_length, dim).
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Index tensor with shape (batch_size, idx_length, dim) where the original
|
| 134 |
+
indices are repeated dim times along the last dimension.
|
| 135 |
+
"""
|
| 136 |
+
dim = tokens.shape[-1]
|
| 137 |
+
index = index.unsqueeze(-1).expand(-1, -1, dim)
|
| 138 |
+
return index
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_at_index(tokens: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
|
| 142 |
+
"""Selects tokens at index.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
tokens:
|
| 146 |
+
Token tensor with shape (batch_size, sequence_length, dim).
|
| 147 |
+
index:
|
| 148 |
+
Index tensor with shape (batch_size, index_length) where each entry is
|
| 149 |
+
an index in [0, sequence_length).
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
Token tensor with shape (batch_size, index_length, dim) containing the
|
| 153 |
+
selected tokens.
|
| 154 |
+
"""
|
| 155 |
+
index = expand_index_like(index, tokens)
|
| 156 |
+
return torch.gather(tokens, 1, index)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def set_at_index(
|
| 160 |
+
tokens: torch.Tensor, index: torch.Tensor, value: torch.Tensor
|
| 161 |
+
) -> torch.Tensor:
|
| 162 |
+
"""Copies all values into the input tensor at the given indices.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
tokens: Tokens tensor with shape (batch_size, sequence_length, dim).
|
| 166 |
+
index: Index tensor with shape (batch_size, index_length).
|
| 167 |
+
value: Value tensor with shape (batch_size, index_length, dim).
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
Tokens tensor with shape (batch_size, sequence_length, dim) containing
|
| 171 |
+
the new values.
|
| 172 |
+
"""
|
| 173 |
+
index = expand_index_like(index, tokens)
|
| 174 |
+
return torch.scatter(tokens, 1, index, value)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def mask_at_index(
|
| 178 |
+
tokens: torch.Tensor, index: torch.Tensor, mask_token: torch.Tensor
|
| 179 |
+
) -> torch.Tensor:
|
| 180 |
+
"""Copies mask token into the input tensor at the given indices.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
tokens:
|
| 184 |
+
Tokens tensor with shape (batch_size, sequence_length, dim).
|
| 185 |
+
index:
|
| 186 |
+
Index tensor with shape (batch_size, index_length).
|
| 187 |
+
mask_token:
|
| 188 |
+
Value tensor with shape (1, 1, dim).
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
Tokens tensor with shape (batch_size, sequence_length, dim) containing
|
| 192 |
+
the new values.
|
| 193 |
+
|
| 194 |
+
"""
|
| 195 |
+
mask = tokens.new_zeros(tokens.shape)
|
| 196 |
+
mask = set_at_index(mask, index, 1)
|
| 197 |
+
return (1 - mask) * tokens + mask * mask_token
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def mask_bool(tokens: torch.Tensor, mask: torch.Tensor, mask_token: torch.Tensor) -> torch.Tensor:
|
| 201 |
+
"""Returns a tensor with tokens replaced by the mask tokens in all positions where
|
| 202 |
+
the mask is True.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
tokens:
|
| 206 |
+
Tokens tensor with shape (batch_size, sequence_length, dim).
|
| 207 |
+
mask:
|
| 208 |
+
Boolean mask tensor with shape (batch_size, sequence_length).
|
| 209 |
+
mask_token:
|
| 210 |
+
Mask token with shape (1, 1, dim).
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
Tokens tensor with shape (batch_size, sequence_length, dim) where tokens[i, j]
|
| 214 |
+
is replaced by the mask token if mask[i, j] is True.
|
| 215 |
+
"""
|
| 216 |
+
# Convert to int for multiplication.
|
| 217 |
+
mask = mask.unsqueeze(-1).to(torch.bool).to(torch.int)
|
| 218 |
+
return (1 - mask) * tokens + mask * mask_token
|
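Both `mask_at_index` and `mask_bool` replace tokens with an arithmetic blend, `(1 - mask) * tokens + mask * mask_token`, rather than an explicit selection. A minimal sketch of that blend for a single sequence using plain Python lists; `mask_bool_1d` is a hypothetical name and the values are made up:

```python
from typing import List


def mask_bool_1d(
    tokens: List[List[float]], mask: List[bool], mask_token: List[float]
) -> List[List[float]]:
    """Replace tokens[j] by mask_token wherever mask[j] is True, via blending."""
    out = []
    for tok, m in zip(tokens, mask):
        flag = int(m)  # 1 where masked, 0 elsewhere
        out.append([(1 - flag) * t + flag * mt for t, mt in zip(tok, mask_token)])
    return out


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mask = [False, True, False]
mask_token = [0.0, -1.0]

print(mask_bool_1d(tokens, mask, mask_token))
```

One property of the blend worth keeping in mind: a NaN or infinity in a masked token still propagates into the result, because multiplying it by zero does not yield zero. A selection such as `torch.where` would avoid that, though the original code does not use it.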
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def patchify(images: torch.Tensor, patch_size: Tuple[int, int, int]) -> torch.Tensor:
|
| 222 |
+
"""Converts a batch of input images into patches.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
images:
|
| 226 |
+
Images tensor with shape (batch_size, channels, height, width, depth)
|
| 227 |
+
patch_size:
|
| 228 |
+
Patch size in voxels as (height, width, depth). Image height, width, and depth must be multiples of
|
| 229 |
+
the patch size.
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
Patches tensor with shape (batch_size, num_patches, channels * math.prod(patch_size))
|
| 233 |
+
where num_patches = (height / patch_size[0]) * (width / patch_size[1]) * (depth / patch_size[2]).
|
| 234 |
+
|
| 235 |
+
"""
|
| 236 |
+
N, C, H, W, D = images.shape
|
| 237 |
+
assert (
|
| 238 |
+
H % patch_size[0] == 0
|
| 239 |
+
and W % patch_size[1] == 0
|
| 240 |
+
and D % patch_size[2] == 0
|
| 241 |
+
), "Image height, width, and depth must be multiples of the patch size."
|
| 242 |
+
|
| 243 |
+
patch_h = H // patch_size[0]
|
| 244 |
+
patch_w = W // patch_size[1]
|
| 245 |
+
patch_d = D // patch_size[2]
|
| 246 |
+
|
| 247 |
+
num_patches = patch_h * patch_w * patch_d
|
| 248 |
+
patches = images.reshape(shape=(
|
| 249 |
+
N, C,
|
| 250 |
+
patch_h, patch_size[0],
|
| 251 |
+
patch_w, patch_size[1],
|
| 252 |
+
patch_d, patch_size[2],
|
| 253 |
+
))
|
| 254 |
+
patches = torch.einsum("nchpwqdr->nhwdpqrc", patches)
|
| 255 |
+
patches = patches.reshape(shape=(N, num_patches, math.prod(patch_size) * C))
|
| 256 |
+
return patches
|
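The shape bookkeeping in `patchify` can be checked without tensors. The following is a minimal pure-Python sketch of that arithmetic only; the helper name `patch_grid` and the example volume size are assumptions made for illustration and are not part of the SPECTRE code.

```python
import math
from typing import Tuple


def patch_grid(
    volume_shape: Tuple[int, int, int],
    patch_size: Tuple[int, int, int],
    channels: int = 1,
) -> Tuple[Tuple[int, int, int], int, int]:
    """Return (grid, num_patches, patch_dim) for a (height, width, depth) volume.

    Hypothetical helper that mirrors the shape arithmetic of patchify.
    """
    if any(s % p != 0 for s, p in zip(volume_shape, patch_size)):
        raise ValueError(
            "Volume height, width, and depth must be multiples of the patch size."
        )
    # Number of patches along each spatial axis.
    grid = tuple(s // p for s, p in zip(volume_shape, patch_size))
    # Total token count and flattened size of one patch.
    num_patches = math.prod(grid)
    patch_dim = channels * math.prod(patch_size)
    return grid, num_patches, patch_dim


grid, num_patches, patch_dim = patch_grid((128, 128, 64), (16, 16, 8))
print(grid, num_patches, patch_dim)  # (8, 8, 8) 512 2048
```

Under these assumed sizes, a single-channel 128 x 128 x 64 volume would yield 512 tokens of length 2048 each.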
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def random_token_mask(
    size: Tuple[int, int],
    mask_ratio: float = 0.6,
    mask_class_token: bool = False,
    device: Optional[Union[torch.device, str]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Creates random token masks.

    Args:
        size:
            Size of the token batch for which to generate masks.
            Should be (batch_size, sequence_length).
        mask_ratio:
            Fraction of tokens to mask.
        mask_class_token:
            If False the class token is never masked. If True the class token
            might be masked.
        device:
            Device on which to create the index masks.

    Returns:
        A (index_keep, index_mask) tuple where each index is a tensor.
        index_keep contains the indices of the unmasked tokens and has shape
        (batch_size, num_keep). index_mask contains the indices of the masked
        tokens and has shape (batch_size, sequence_length - num_keep).
        num_keep is equal to int(sequence_length * (1 - mask_ratio)), and is at
        least 1 when the class token is protected from masking.

    """
    batch_size, sequence_length = size
    num_keep = int(sequence_length * (1 - mask_ratio))

    noise = torch.rand(batch_size, sequence_length, device=device)
    if not mask_class_token and sequence_length > 0:
        # make sure that class token is not masked
        noise[:, 0] = -1
        num_keep = max(1, num_keep)

    # get indices of tokens to keep
    indices = torch.argsort(noise, dim=1)
    idx_keep = indices[:, :num_keep]
    idx_mask = indices[:, num_keep:]

    return idx_keep, idx_mask
|
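The masking trick in `random_token_mask` is to sort per-token random noise and keep the first `num_keep` positions. The following is a hedged pure-Python sketch for a single sequence; plain lists stand in for tensors, and the function name `random_token_mask_1d` and the `rng` parameter are assumptions for illustration only.

```python
import random
from typing import List, Optional, Tuple


def random_token_mask_1d(
    sequence_length: int,
    mask_ratio: float = 0.6,
    mask_class_token: bool = False,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Sketch of the argsort-over-noise trick for one sequence (not the library API)."""
    rng = rng or random.Random()
    num_keep = int(sequence_length * (1 - mask_ratio))
    noise = [rng.random() for _ in range(sequence_length)]
    if not mask_class_token and sequence_length > 0:
        # Give the class token (index 0) the smallest noise so it sorts first
        # and is therefore always among the kept tokens.
        noise[0] = -1.0
        num_keep = max(1, num_keep)
    # Equivalent of argsort: positions ordered by ascending noise.
    order = sorted(range(sequence_length), key=lambda i: noise[i])
    return order[:num_keep], order[num_keep:]


idx_keep, idx_mask = random_token_mask_1d(8, mask_ratio=0.5, rng=random.Random(0))
print(len(idx_keep), len(idx_mask))  # 4 4
print(idx_keep[0])  # 0, the class token
```

Because the kept and masked lists come from one permutation, they are disjoint and together cover every position exactly once.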
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def resample_abs_pos_embed(
|
| 305 |
+
posemb: torch.Tensor,
|
| 306 |
+
new_size: List[int],
|
| 307 |
+
old_size: List[int],
|
| 308 |
+
num_prefix_tokens: int = 1,
|
| 309 |
+
interpolation: str = 'trilinear',
|
| 310 |
+
):
|
| 311 |
+
# sort out sizes; old_size must describe the existing 3D position grid
|
| 312 |
+
num_pos_tokens = posemb.shape[1]
|
| 313 |
+
num_new_tokens = new_size[0] * new_size[1] * new_size[2] + num_prefix_tokens
|
| 314 |
+
if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
|
| 315 |
+
return posemb
|
| 316 |
+
|
| 317 |
+
if num_prefix_tokens:
|
| 318 |
+
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
|
| 319 |
+
else:
|
| 320 |
+
posemb_prefix, posemb = None, posemb
|
| 321 |
+
|
| 322 |
+
# do the interpolation
|
| 323 |
+
embed_dim = posemb.shape[-1]
|
| 324 |
+
orig_dtype = posemb.dtype
|
| 325 |
+
posemb = posemb.float() # interpolate needs float32
|
| 326 |
+
posemb = posemb.reshape(1, old_size[0], old_size[1], old_size[2], -1).permute(0, 4, 1, 2, 3)
|
| 327 |
+
posemb = F.interpolate(posemb, size=new_size, mode=interpolation)
|
| 328 |
+
posemb = posemb.permute(0, 2, 3, 4, 1).reshape(1, -1, embed_dim)
|
| 329 |
+
posemb = posemb.to(orig_dtype)
|
| 330 |
+
|
| 331 |
+
# add back extra (class, etc) prefix tokens
|
| 332 |
+
if posemb_prefix is not None:
|
| 333 |
+
posemb = torch.cat([posemb_prefix, posemb], dim=1)
|
| 334 |
+
|
| 335 |
+
return posemb
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def resample_abs_pos_embed_nhwdc(
|
| 339 |
+
posemb: torch.Tensor,
|
| 340 |
+
new_size: List[int],
|
| 341 |
+
interpolation: str = 'trilinear',
|
| 342 |
+
):
|
| 343 |
+
if new_size[0] == posemb.shape[-4] and new_size[1] == posemb.shape[-3] and new_size[2] == posemb.shape[-2]:
|
| 344 |
+
return posemb
|
| 345 |
+
|
| 346 |
+
orig_dtype = posemb.dtype
|
| 347 |
+
posemb = posemb.float()
|
| 348 |
+
posemb = posemb.reshape(1, posemb.shape[-4], posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 4, 1, 2, 3)
|
| 349 |
+
posemb = F.interpolate(posemb, size=new_size, mode=interpolation)
|
| 350 |
+
posemb = posemb.permute(0, 2, 3, 4, 1).to(orig_dtype)
|
| 351 |
+
|
| 352 |
+
return posemb
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def resample_patch_embed(
|
| 356 |
+
patch_embed,
|
| 357 |
+
new_size: List[int],
|
| 358 |
+
interpolation: str = 'trilinear',
|
| 359 |
+
):
|
| 360 |
+
"""Resample the weights of the patch embedding kernel to target resolution.
|
| 361 |
+
We resample the patch embedding kernel by approximately inverting the effect
|
| 362 |
+
of patch resizing.
|
| 363 |
+
|
| 364 |
+
Code based on:
|
| 365 |
+
https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
|
| 366 |
+
|
| 367 |
+
With this resizing, we can for example load a B/8 filter into a B/16 model
|
| 368 |
+
and, on 2x larger input image, the result will match.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
patch_embed: original parameter to be resized.
|
| 372 |
+
new_size (tuple[int, int, int]): target shape (height, width, depth).
|
| 373 |
+
interpolation (str): interpolation for resize
|
| 374 |
+
Returns:
|
| 375 |
+
Resized patch embedding kernel.
|
| 376 |
+
"""
|
| 377 |
+
import numpy as np
|
| 378 |
+
try:
|
| 379 |
+
from torch import vmap
|
| 380 |
+
except ImportError:
|
| 381 |
+
from functorch import vmap
|
| 382 |
+
|
| 383 |
+
assert len(patch_embed.shape) == 5, "Five dimensions expected"
|
| 384 |
+
assert len(new_size) == 3, "New shape should only be (height, width, depth)"
|
| 385 |
+
old_size = patch_embed.shape[-3:]
|
| 386 |
+
if tuple(old_size) == tuple(new_size):
|
| 387 |
+
return patch_embed
|
| 388 |
+
|
| 389 |
+
def resize(x_np, _new_size):
|
| 390 |
+
x_tf = torch.Tensor(x_np)[None, None, ...]
|
| 391 |
+
x_upsampled = F.interpolate(
|
| 392 |
+
x_tf, size=_new_size, mode=interpolation)[0, 0, ...].numpy()
|
| 393 |
+
return x_upsampled
|
| 394 |
+
|
| 395 |
+
def get_resize_mat(_old_size, _new_size):
|
| 396 |
+
mat = []
|
| 397 |
+
for i in range(np.prod(_old_size)):
|
| 398 |
+
basis_vec = np.zeros(_old_size)
|
| 399 |
+
basis_vec[np.unravel_index(i, _old_size)] = 1.
|
| 400 |
+
mat.append(resize(basis_vec, _new_size).reshape(-1))
|
| 401 |
+
return np.stack(mat).T
|
| 402 |
+
|
| 403 |
+
resize_mat = get_resize_mat(old_size, new_size)
|
| 404 |
+
resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
|
| 405 |
+
|
| 406 |
+
def resample_kernel(kernel):
|
| 407 |
+
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
|
| 408 |
+
return resampled_kernel.reshape(new_size)
|
| 409 |
+
|
| 410 |
+
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
|
| 411 |
+
orig_dtype = patch_embed.dtype
|
| 412 |
+
patch_embed = patch_embed.float()
|
| 413 |
+
patch_embed = v_resample_kernel(patch_embed)
|
| 414 |
+
patch_embed = patch_embed.to(orig_dtype)
|
| 415 |
+
return patch_embed
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def feature_take_indices(
|
| 419 |
+
num_features: int,
|
| 420 |
+
indices: Optional[Union[int, List[int]]] = None,
|
| 421 |
+
as_set: bool = False,
|
| 422 |
+
) -> Tuple[List[int], int]:
|
| 423 |
+
""" Determine the absolute feature indices to 'take' from.
|
| 424 |
+
|
| 425 |
+
Note: This function can be called in forward() so must be torchscript compatible,
|
| 426 |
+
which requires some incomplete typing and workaround hacks.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
num_features: total number of features to select from
|
| 430 |
+
indices: indices to select,
|
| 431 |
+
None -> select all
|
| 432 |
+
int -> select last n
|
| 433 |
+
list/tuple of int -> return specified (-ve indices specify from end)
|
| 434 |
+
as_set: return as a set
|
| 435 |
+
|
| 436 |
+
Returns:
|
| 437 |
+
List (or set) of absolute (from beginning) indices, Maximum index
|
| 438 |
+
"""
|
| 439 |
+
if indices is None:
|
| 440 |
+
indices = num_features # all features if None
|
| 441 |
+
|
| 442 |
+
if isinstance(indices, int):
|
| 443 |
+
# convert int -> last n indices
|
| 444 |
+
assert 0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})'
|
| 445 |
+
take_indices = [num_features - indices + i for i in range(indices)]
|
| 446 |
+
else:
|
| 447 |
+
take_indices: List[int] = []
|
| 448 |
+
for i in indices:
|
| 449 |
+
idx = num_features + i if i < 0 else i
|
| 450 |
+
assert 0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})'
|
| 451 |
+
take_indices.append(idx)
|
| 452 |
+
|
| 453 |
+
if not torch.jit.is_scripting() and as_set:
|
| 454 |
+
return set(take_indices), max(take_indices)
|
| 455 |
+
|
| 456 |
+
return take_indices, max(take_indices)
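The index resolution performed by `feature_take_indices` can be illustrated with a small stand-alone sketch. It assumes plain Python without the TorchScript and `as_set` handling; the name `resolve_indices` is hypothetical.

```python
from typing import List, Optional, Tuple, Union


def resolve_indices(
    num_features: int,
    indices: Optional[Union[int, List[int]]] = None,
) -> Tuple[List[int], int]:
    """Sketch: turn 'last n' or possibly negative indices into absolute ones."""
    if indices is None:
        indices = num_features  # None selects every feature
    if isinstance(indices, int):
        if not 0 < indices <= num_features:
            raise ValueError(f"last-n ({indices}) is out of range (1 to {num_features})")
        # An int means "the last n features".
        take = list(range(num_features - indices, num_features))
    else:
        take = []
        for i in indices:
            # Negative values count from the end, as in normal Python indexing.
            idx = num_features + i if i < 0 else i
            if not 0 <= idx < num_features:
                raise ValueError(f"feature index {idx} is out of range")
            take.append(idx)
    return take, max(take)


print(resolve_indices(12, 3))        # ([9, 10, 11], 11)
print(resolve_indices(12, [-1, 2]))  # ([11, 2], 11)
```

The second return value is the deepest layer needed, which a caller could use to stop the forward pass early.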
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def global_pool_nlc(
|
| 460 |
+
x: torch.Tensor,
|
| 461 |
+
pool_type: str = 'token',
|
| 462 |
+
num_prefix_tokens: int = 1,
|
| 463 |
+
reduce_include_prefix: bool = False,
|
| 464 |
+
):
|
| 465 |
+
if not pool_type:
|
| 466 |
+
return x
|
| 467 |
+
|
| 468 |
+
if pool_type == 'token':
|
| 469 |
+
x = x[:, 0] # class token
|
| 470 |
+
else:
|
| 471 |
+
x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
|
| 472 |
+
if pool_type == 'avg':
|
| 473 |
+
x = x.mean(dim=1)
|
| 474 |
+
elif pool_type == 'avgmax':
|
| 475 |
+
x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
|
| 476 |
+
elif pool_type == 'max':
|
| 477 |
+
x = x.amax(dim=1)
|
| 478 |
+
else:
|
| 479 |
+
assert not pool_type, f'Unknown pool type {pool_type}'
|
| 480 |
+
|
| 481 |
+
return x
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def cat_keep_shapes(
|
| 485 |
+
x_list: List[torch.Tensor]
|
| 486 |
+
) -> Tuple[torch.Tensor, List[Tuple[int, ...]], List[int]]:
|
| 487 |
+
if not x_list:
|
| 488 |
+
return torch.empty(0), [], []
|
| 489 |
+
|
| 490 |
+
shapes = [x.shape for x in x_list]
|
| 491 |
+
num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list]
|
| 492 |
+
x_cat = torch.cat([x.flatten(0, -2) for x in x_list], dim=0)
|
| 493 |
+
|
| 494 |
+
return x_cat, shapes, num_tokens
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def uncat_with_shapes(
|
| 498 |
+
x_cat: torch.Tensor,
|
| 499 |
+
shapes: List[Tuple[int, ...]],
|
| 500 |
+
num_tokens: List[int]
|
| 501 |
+
) -> List[torch.Tensor]:
|
| 502 |
+
if not shapes:
|
| 503 |
+
return []
|
| 504 |
+
|
| 505 |
+
x_splitted = torch.split_with_sizes(x_cat, num_tokens, dim=0)
|
| 506 |
+
shapes_adjusted = [shape[:-1] + torch.Size([x_cat.shape[-1]]) for shape in shapes]
|
| 507 |
+
outputs_reshape = [x.reshape(shape) for x, shape in zip(x_splitted, shapes_adjusted)]
|
| 508 |
+
|
| 509 |
+
return outputs_reshape
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def last_token_pool(
|
| 513 |
+
last_hidden_states: torch.Tensor,
|
| 514 |
+
attention_mask: torch.Tensor
|
| 515 |
+
) -> torch.Tensor:
|
| 516 |
+
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
| 517 |
+
if left_padding:
|
| 518 |
+
return last_hidden_states[:, -1]
|
| 519 |
+
else:
|
| 520 |
+
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 521 |
+
batch_size = last_hidden_states.shape[0]
|
| 522 |
+
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device),
|
| 523 |
+
sequence_lengths]
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class Format(str, Enum):
|
| 527 |
+
NCHWD = 'NCHWD'
|
| 528 |
+
NHWDC = 'NHWDC'
|
| 529 |
+
NCL = 'NCL'
|
| 530 |
+
NLC = 'NLC'
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def nchwd_to(x: torch.Tensor, fmt: Format):
|
| 534 |
+
if fmt == Format.NHWDC:
|
| 535 |
+
x = x.permute(0, 2, 3, 4, 1)
|
| 536 |
+
elif fmt == Format.NLC:
|
| 537 |
+
x = x.flatten(2).transpose(1, 2)
|
| 538 |
+
elif fmt == Format.NCL:
|
| 539 |
+
x = x.flatten(2)
|
| 540 |
+
return x
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def nhwdc_to(x: torch.Tensor, fmt: Format):
    if fmt == Format.NCHWD:
        x = x.permute(0, 4, 1, 2, 3)
    elif fmt == Format.NLC:
        # Flatten all three spatial dims (H, W, D) into a single token axis.
        x = x.flatten(1, 3)
    elif fmt == Format.NCL:
        x = x.flatten(1, 3).transpose(1, 2)
    return x
|
spectre/utils/param_groups.py
ADDED
|
@@ -0,0 +1,118 @@
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_vit_lr_decay_rate(
|
| 5 |
+
name: str,
|
| 6 |
+
llrd_factor: float = 1.0,
|
| 7 |
+
num_layers: int = 12,
|
| 8 |
+
force_is_backbone: bool = False,
|
| 9 |
+
shift: int = 0,
|
| 10 |
+
) -> float:
|
| 11 |
+
"""
|
| 12 |
+
Get the layer-wise learning rate decay (LLRD) rate for a given parameter name.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
name:
|
| 16 |
+
The name of the parameter.
|
| 17 |
+
llrd_factor:
|
| 18 |
+
The decay factor for each layer.
|
| 19 |
+
num_layers:
|
| 20 |
+
The total number of layers in the model.
|
| 21 |
+
force_is_backbone:
|
| 22 |
+
If True, forces the function to treat the parameter as part of the backbone.
|
| 23 |
+
shift:
|
| 24 |
+
An integer to shift the layer ids, useful when combining multiple modules.
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
The learning rate multiplier for the parameter.
|
| 28 |
+
"""
|
| 29 |
+
layer_id = num_layers + 1
|
| 30 |
+
if name.startswith("backbone") or force_is_backbone:
|
| 31 |
+
if (
|
| 32 |
+
".pos_embed" in name
|
| 33 |
+
or ".patch_embed" in name
|
| 34 |
+
or ".patch_proj" in name
|
| 35 |
+
or ".mask_token" in name
|
| 36 |
+
or ".cls_token" in name
|
| 37 |
+
or ".reg_token" in name
|
| 38 |
+
):
|
| 39 |
+
layer_id = 0
|
| 40 |
+
elif ".blocks." in name:
|
| 41 |
+
layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + shift
|
| 42 |
+
|
| 43 |
+
return llrd_factor ** (num_layers + 1 - layer_id)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_param_groups_with_decay(
|
| 47 |
+
model: nn.Module,
|
| 48 |
+
llrd_factor: float = 1.0,
|
| 49 |
+
patch_embed_lr_mult: float = 1.0,
|
| 50 |
+
projection_head_wd_mult: float = 1.0,
|
| 51 |
+
lora_lr_factor: float = 1.0,
|
| 52 |
+
num_layers: int | None = None,
|
| 53 |
+
):
|
| 54 |
+
|
| 55 |
+
force_is_backbone = False
|
| 56 |
+
shift = 0
|
| 57 |
+
if num_layers is not None:
|
| 58 |
+
pass  # keep the explicitly provided num_layers
|
| 59 |
+
elif hasattr(model, "n_blocks"):
|
| 60 |
+
num_layers = model.n_blocks
|
| 61 |
+
force_is_backbone = True
|
| 62 |
+
elif hasattr(model, "blocks"):
|
| 63 |
+
num_layers = len(model.blocks)
|
| 64 |
+
force_is_backbone = True
|
| 65 |
+
elif hasattr(model, "backbone") and hasattr(model.backbone, "blocks"):
|
| 66 |
+
num_layers = len(model.backbone.blocks)
|
| 67 |
+
elif hasattr(model, "backbone_student") and hasattr(model.backbone_student, "blocks"): # DINO specific
|
| 68 |
+
num_layers = len(model.backbone_student.blocks)
|
| 69 |
+
elif hasattr(model, "backbone_student") and hasattr(model.backbone_student, "vit") and hasattr(model.backbone_student.vit, "blocks"): # DINOv2 specific
|
| 70 |
+
num_layers = len(model.backbone_student.vit.blocks)
|
| 71 |
+
elif hasattr(model, "backbone_image") and hasattr(model.backbone_image, "blocks"): # SigLIP specific
|
| 72 |
+
if not hasattr(model, "feature_comb_image") or model.feature_comb_image is None:
|
| 73 |
+
num_layers = len(model.backbone_image.blocks)
|
| 74 |
+
else:
|
| 75 |
+
num_layers = len(model.backbone_image.blocks) + len(model.feature_comb_image.blocks)
|
| 76 |
+
shift = len(model.backbone_image.blocks)
|
| 77 |
+
force_is_backbone = True
|
| 78 |
+
else:
|
| 79 |
+
num_layers = 0
|
| 80 |
+
|
| 81 |
+
all_param_groups = []
|
| 82 |
+
for n, p in model.named_parameters():
|
| 83 |
+
if not p.requires_grad:
|
| 84 |
+
continue
|
| 85 |
+
if "lora_" not in n:
|
| 86 |
+
s = shift if "feature_comb" in n else 0
|
| 87 |
+
llrd_rate = get_vit_lr_decay_rate(
|
| 88 |
+
n, llrd_factor, num_layers, force_is_backbone, s,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
d = {
|
| 92 |
+
"name": n,
|
| 93 |
+
"params": p,
|
| 94 |
+
"lr_mult": llrd_rate,
|
| 95 |
+
"wd_mult": 1.0,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
if "head" in n or "projection" in n:
|
| 99 |
+
d["wd_mult"] = projection_head_wd_mult
|
| 100 |
+
|
| 101 |
+
# No weight-decay on biases, norm parameters, layer scale gamma, learned tokens and embeddings
|
| 102 |
+
if n.endswith("bias") or "norm" in n or "gamma" in n or "fourrier_w" in n:
|
| 103 |
+
d["wd_mult"] = 0.0
|
| 104 |
+
|
| 105 |
+
if "patch_embed" in n:
|
| 106 |
+
d["lr_mult"] *= patch_embed_lr_mult
|
| 107 |
+
|
| 108 |
+
else:
|
| 109 |
+
# LoRA parameters
|
| 110 |
+
d = {
|
| 111 |
+
"name": n,
|
| 112 |
+
"params": p,
|
| 113 |
+
"lr_mult": lora_lr_factor,
|
| 114 |
+
"wd_mult": 1.0,
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
all_param_groups.append(d)
|
| 118 |
+
return all_param_groups
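To make the layer-wise learning rate decay concrete, the following sketch reproduces the core rule of `get_vit_lr_decay_rate` in plain Python. It is a simplified illustration: the function name `lr_decay_rate`, the reduced keyword list, and the parameter names in the loop are assumptions, and the `shift` and `force_is_backbone` options are omitted.

```python
import math


def lr_decay_rate(name: str, llrd_factor: float = 1.0, num_layers: int = 12) -> float:
    """Sketch of layer-wise LR decay: deeper layers get a larger multiplier."""
    # Parameters outside the backbone (e.g. heads) sit "after" the last block.
    layer_id = num_layers + 1
    if name.startswith("backbone"):
        if any(k in name for k in (".pos_embed", ".patch_embed", ".cls_token")):
            layer_id = 0  # embeddings and learned tokens decay the most
        elif ".blocks." in name:
            # "backbone.blocks.11.attn..." -> block index 11 -> layer_id 12
            layer_id = int(name[name.find(".blocks."):].split(".")[2]) + 1
    return llrd_factor ** (num_layers + 1 - layer_id)


# Hypothetical parameter names, chosen only to exercise each branch.
for n in ("backbone.pos_embed", "backbone.blocks.0.attn.qkv.weight",
          "backbone.blocks.11.mlp.fc1.weight", "head.weight"):
    print(n, lr_decay_rate(n, llrd_factor=0.75, num_layers=12))
```

With a factor of 0.75 and 12 blocks, the last block is scaled by 0.75, the first block by 0.75 to the power 12, the embeddings by 0.75 to the power 13, and the head keeps a multiplier of 1.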
|
spectre/utils/scheduler.py
ADDED
|
@@ -0,0 +1,236 @@
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def linear_warmup_schedule(
|
| 9 |
+
step: int,
|
| 10 |
+
warmup_steps: int,
|
| 11 |
+
start_value: float,
|
| 12 |
+
end_value: float,
|
| 13 |
+
) -> float:
|
| 14 |
+
if warmup_steps < 0:
|
| 15 |
+
raise ValueError(f"Warmup steps {warmup_steps} can't be negative.")
|
| 16 |
+
if step < 0:
|
| 17 |
+
raise ValueError(f"Current step number {step} can't be negative.")
|
| 18 |
+
if start_value < 0:
|
| 19 |
+
raise ValueError(f"Start value {start_value} can't be negative.")
|
| 20 |
+
if end_value <= 0:
|
| 21 |
+
raise ValueError(f"End value {end_value} can't be non-positive.")
|
| 22 |
+
if start_value > end_value:
|
| 23 |
+
raise ValueError(
|
| 24 |
+
f"Start value {start_value} must be less than or equal to end value {end_value}."
|
| 25 |
+
)
|
| 26 |
+
if step < warmup_steps:
|
| 27 |
+
return start_value + step / warmup_steps * (end_value - start_value)
|
| 28 |
+
else:
|
| 29 |
+
return end_value
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def cosine_schedule(
|
| 33 |
+
step: int,
|
| 34 |
+
max_steps: int,
|
| 35 |
+
start_value: float,
|
| 36 |
+
end_value: float,
|
| 37 |
+
period: Optional[int] = None,
|
| 38 |
+
) -> float:
|
| 39 |
+
"""Use cosine decay to gradually modify start_value to reach target end_value.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
step:
|
| 43 |
+
Current step number.
|
| 44 |
+
max_steps:
|
| 45 |
+
Total number of steps.
|
| 46 |
+
start_value:
|
| 47 |
+
Starting value.
|
| 48 |
+
end_value:
|
| 49 |
+
Target value.
|
| 50 |
+
period:
|
| 51 |
+
The number of steps over which the cosine function completes a full cycle.
|
| 52 |
+
Defaults to max_steps.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Cosine decay value.
|
| 56 |
+
|
| 57 |
+
"""
|
| 58 |
+
if step < 0:
|
| 59 |
+
raise ValueError(f"Current step number {step} can't be negative.")
|
| 60 |
+
if max_steps < 1:
|
| 61 |
+
raise ValueError(f"Total step number {max_steps} must be >= 1.")
|
| 62 |
+
if period is None and step > max_steps:
|
| 63 |
+
warnings.warn(
|
| 64 |
+
f"Current step number {step} exceeds max_steps {max_steps}.",
|
| 65 |
+
category=RuntimeWarning,
|
| 66 |
+
)
|
| 67 |
+
if period is not None and period <= 0:
|
| 68 |
+
raise ValueError(f"Period {period} must be >= 1")
|
| 69 |
+
|
| 70 |
+
decay: float
|
| 71 |
+
if period is not None: # "cycle" based on period, if provided
|
| 72 |
+
decay = (
|
| 73 |
+
end_value
|
| 74 |
+
- (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2
|
| 75 |
+
)
|
| 76 |
+
elif max_steps == 1:
|
| 77 |
+
# Avoid division by zero
|
| 78 |
+
decay = end_value
|
| 79 |
+
elif step == max_steps:
|
| 80 |
+
# Special case for Pytorch Lightning which updates LR scheduler also for epoch
|
| 81 |
+
# after last training epoch.
|
| 82 |
+
decay = end_value
|
| 83 |
+
else:
|
| 84 |
+
decay = (
|
| 85 |
+
end_value
|
| 86 |
+
- (end_value - start_value)
|
| 87 |
+
* (np.cos(np.pi * step / (max_steps - 1)) + 1)
|
| 88 |
+
/ 2
|
| 89 |
+
)
|
| 90 |
+
return decay
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def cosine_warmup_schedule(
|
| 94 |
+
step: int,
|
| 95 |
+
max_steps: int,
|
| 96 |
+
start_value: float,
|
| 97 |
+
end_value: float,
|
| 98 |
+
warmup_steps: int,
|
| 99 |
+
warmup_start_value: float,
|
| 100 |
+
warmup_end_value: Optional[float] = None,
|
| 101 |
+
period: Optional[int] = None,
|
| 102 |
+
) -> float:
|
| 103 |
+
"""Use cosine decay to gradually modify start_value to reach target end_value.
|
| 104 |
+
|
| 105 |
+
Uses linear warmup for the first warmup_steps steps.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
step:
|
| 109 |
+
Current step number.
|
| 110 |
+
max_steps:
|
| 111 |
+
Total number of steps.
|
| 112 |
+
start_value:
|
| 113 |
+
Starting value.
|
| 114 |
+
end_value:
|
| 115 |
+
Target value.
|
| 116 |
+
warmup_steps:
|
| 117 |
+
Number of steps for warmup.
|
| 118 |
+
warmup_start_value:
|
| 119 |
+
Starting value for warmup.
|
| 120 |
+
warmup_end_value:
|
| 121 |
+
Target value for warmup. Defaults to start_value.
|
| 122 |
+
period:
|
| 123 |
+
The number of steps over which the cosine function completes a full cycle.
|
| 124 |
+
Defaults to max_steps - warmup_steps.
|
| 125 |
+
Returns:
|
| 126 |
+
Cosine decay value.
|
| 127 |
+
"""
|
| 128 |
+
if warmup_steps < 0:
|
| 129 |
+
raise ValueError(f"Warmup steps {warmup_steps} can't be negative.")
|
| 130 |
+
if warmup_steps > max_steps:
|
| 131 |
+
raise ValueError(f"Warmup steps {warmup_steps} must be <= max_steps.")
|
| 132 |
+
if step > max_steps:
|
| 133 |
+
warnings.warn(
|
| 134 |
+
f"Current step number {step} exceeds max_steps {max_steps}.",
|
| 135 |
+
category=RuntimeWarning,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
if warmup_end_value is None:
|
| 139 |
+
warmup_end_value = start_value
|
| 140 |
+
|
| 141 |
+
if step < warmup_steps:
|
| 142 |
+
# Use step + 1 to reach warmup_end_value at end of warmup. This means that the
|
| 143 |
+
# initial warmup_start_value is skipped which is oftentimes desired when setting
|
| 144 |
+
# it to 0 as this would result in no parameter updates.
|
| 145 |
+
return (
|
| 146 |
+
warmup_start_value
|
| 147 |
+
+ (warmup_end_value - warmup_start_value) * (step + 1) / warmup_steps
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
max_steps = max_steps - warmup_steps if period is None else 1
|
| 151 |
+
return cosine_schedule(
|
| 152 |
+
step=step - warmup_steps,
|
| 153 |
+
max_steps=max_steps,
|
| 154 |
+
start_value=start_value,
|
| 155 |
+
end_value=end_value,
|
| 156 |
+
period=period,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR):
|
| 161 |
+
"""Cosine warmup scheduler for learning rate.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
optimizer:
|
| 165 |
+
Optimizer object to schedule the learning rate.
|
| 166 |
+
warmup_epochs:
|
| 167 |
+
Number of warmup epochs or steps.
|
| 168 |
+
max_epochs:
|
| 169 |
+
Total number of training epochs or steps.
|
| 170 |
+
last_epoch:
|
| 171 |
+
The index of last epoch or step.
|
| 172 |
+
start_value:
|
| 173 |
+
Starting learning rate.
|
| 174 |
+
end_value:
|
| 175 |
+
Target learning rate.
|
| 176 |
+
verbose:
|
| 177 |
+
If True, prints a message to stdout for each update.
|
| 178 |
+
warmup_start_value:
|
| 179 |
+
Starting learning rate for warmup.
|
| 180 |
+
warmup_end_value:
|
| 181 |
+
Target learning rate for warmup. Defaults to start_value.
|
| 182 |
+
|
| 183 |
+
Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index
|
| 184 |
+
can be used. The naming follows the PyTorch convention to use `epoch` for the steps
|
| 185 |
+
in the scheduler.
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
optimizer: torch.optim.Optimizer,
|
| 191 |
+
warmup_epochs: int,
|
| 192 |
+
max_epochs: int,
|
| 193 |
+
last_epoch: int = -1,
|
| 194 |
+
start_value: float = 1.0,
|
| 195 |
+
end_value: float = 0.001,
|
| 196 |
+
period: Optional[int] = None,
|
| 197 |
+
verbose: bool = False,
|
| 198 |
+
warmup_start_value: float = 0.0,
|
| 199 |
+
warmup_end_value: Optional[float] = None,
|
| 200 |
+
) -> None:
|
| 201 |
+
self.warmup_epochs = warmup_epochs
|
| 202 |
+
self.max_epochs = max_epochs
|
| 203 |
+
self.start_value = start_value
|
| 204 |
+
self.end_value = end_value
|
| 205 |
+
self.period = period
|
| 206 |
+
self.warmup_start_value = warmup_start_value
|
| 207 |
+
self.warmup_end_value = warmup_end_value
|
| 208 |
+
|
| 209 |
+
super().__init__(
|
| 210 |
+
optimizer=optimizer,
|
| 211 |
+
lr_lambda=self.scale_lr,
|
| 212 |
+
last_epoch=last_epoch,
|
| 213 |
+
verbose=verbose,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
def scale_lr(self, epoch: int) -> float:
|
| 217 |
+
"""Scale learning rate according to the current epoch number.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
epoch:
|
| 221 |
+
Current epoch number.
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
Scaled learning rate.
|
| 225 |
+
|
| 226 |
+
"""
|
| 227 |
+
return cosine_warmup_schedule(
|
| 228 |
+
step=epoch,
|
| 229 |
+
max_steps=self.max_epochs,
|
| 230 |
+
start_value=self.start_value,
|
| 231 |
+
end_value=self.end_value,
|
| 232 |
+
warmup_steps=self.warmup_epochs,
|
| 233 |
+
warmup_start_value=self.warmup_start_value,
|
| 234 |
+
warmup_end_value=self.warmup_end_value,
|
| 235 |
+
period=self.period,
|
| 236 |
+
)
|
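The shape of the schedule produced by `cosine_warmup_schedule` may be easier to see in a compact form. The sketch below is a simplified pure-Python version under stated assumptions: it has no `period` support, performs no input validation, and clamps to `end_value` once the decay phase is over, which differs from the original's warning behaviour. The name `warmup_cosine` is hypothetical.

```python
import math


def warmup_cosine(
    step: int,
    max_steps: int,
    warmup_steps: int,
    start_value: float = 1.0,
    end_value: float = 0.001,
    warmup_start_value: float = 0.0,
) -> float:
    """Sketch: linear warmup to start_value, then cosine decay to end_value."""
    if step < warmup_steps:
        # (step + 1) so the warmup target is reached on the last warmup step
        # and the very first step is not stuck at warmup_start_value.
        return warmup_start_value + (start_value - warmup_start_value) * (step + 1) / warmup_steps
    decay_steps = max_steps - warmup_steps
    t = step - warmup_steps
    if decay_steps <= 1 or t >= decay_steps:
        return end_value  # simplification: clamp after the schedule ends
    cos_term = (math.cos(math.pi * t / (decay_steps - 1)) + 1) / 2
    return end_value - (end_value - start_value) * cos_term


# Multipliers for a toy run: 2 warmup steps followed by 4 decay steps.
for s in range(6):
    print(s, warmup_cosine(s, max_steps=6, warmup_steps=2, end_value=0.0))
```

When used through `LambdaLR`, as `CosineWarmupScheduler` does, values like these act as multipliers on each parameter group's base learning rate.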