cclaess committed
Commit 8b41845 · verified · 1 Parent(s): 0f95598

Initial commit

README.md ADDED
@@ -0,0 +1,179 @@
+ 📢 [2026-04-10] SPECTRE is now an official baseline for the [**CVPR 2026 Workshop Competition: Foundation Models for General CT Image Diagnosis**](https://www.codabench.org/competitions/12650/)! See `experiments/cvpr26_fm_for_ct_diag_task_1` for scripts and additional details.
+
+ 📢 [2026-02-21] SPECTRE has been accepted for presentation at **CVPR 2026** (Denver, Colorado, USA)!
+
+ 📢 [2026-01-20] [Semantic segmentation](https://github.com/cviviers/nnUNet) code and configurations using the nnUNet framework are now released!
+
+
+ # SPECTRE 👻👻👻
+
+ <p align="center">
+   <a href="https://pypi.org/project/spectre-fm/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/spectre-fm?style=flat-square&label=version&cacheSeconds=0" /></a>
+   <a href="https://pypi.org/project/spectre-fm/"><img alt="Python Versions" src="https://img.shields.io/pypi/pyversions/spectre-fm?style=flat-square&cacheSeconds=0" /></a>
+   <a href="https://pypi.org/project/spectre-fm/"><img alt="Downloads per Month" src="https://img.shields.io/pypi/dm/spectre-fm?style=flat-square&label=downloads&cacheSeconds=0" /></a>
+   <a href="https://github.com/cclaess/SPECTRE/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/cclaess/SPECTRE?style=flat-square&cacheSeconds=0" /></a>
+   <a href="https://huggingface.co/cclaess/SPECTRE"><img alt="Model weights" src="https://img.shields.io/badge/models-Hugging%20Face-yellow?style=flat-square&cacheSeconds=0" /></a>
+   <a href="https://arxiv.org/abs/2511.17209"><img alt="Paper" src="https://img.shields.io/badge/paper-arXiv-b31b1b?style=flat-square&cacheSeconds=0" /></a>
+ </p>
+
+ <p align="center">
+   <img src="imgs/method_overview.jpg" alt="SPECTRE architecture and pretraining strategies" width="600"/>
+ </p>
+
+ SPECTRE (**S**elf-Supervised & Cross-Modal **P**r**e**training for **CT** **R**epresentation **E**xtraction) is a **Transformer-based foundation model for 3D Computed Tomography (CT) scans**, trained using **self-supervised learning** (SSL) and **cross-modal vision–language alignment** (VLA). It provides rich and generalizable representations from medical imaging data, which can be fine-tuned for downstream tasks such as segmentation, classification, and anomaly detection.
+
+ SPECTRE has been trained on a large cohort of **open-source CT scans** of the **human abdomen and thorax**, as well as **paired radiology reports** and **Electronic Health Record data**, enabling it to capture representations that generalize across datasets and clinical settings.
+
+ This repository provides pretrained SPECTRE models together with tools for fine-tuning and evaluation.
+
+ ## 🧠 Pretrained Models
+ The pretrained SPECTRE model can easily be imported as follows:
+
+ ```python
+ from spectre import SpectreImageFeatureExtractor, MODEL_CONFIGS
+ import torch
+
+ config = MODEL_CONFIGS['spectre-large-pretrained']
+ model = SpectreImageFeatureExtractor.from_config(config)
+ model.eval()
+
+ # Dummy scan of shape (batch, channels, height, width, depth); it is split below into (batch, crops, channels, height, width, depth)
+ # For a (3 x 3 x 4) grid of (128 x 128 x 64) CT patches -> Total scan size (384 x 384 x 256)
+ x = torch.randn(1, 1, 384, 384, 256)
+ B, C, H, W, D = x.shape
+
+ patch_size = (128, 128, 64)
+ pH, pW, pD = patch_size
+
+ x = x.view(
+     B, C,
+     H // pH, pH,
+     W // pW, pW,
+     D // pD, pD,
+ ).permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(B, -1, C, pH, pW, pD)
+
+ with torch.no_grad():
+     features = model(
+         x,
+         grid_size=(
+             H // pH,
+             W // pW,
+             D // pD,
+         ),
+     )
+ print("Features shape:", features.shape)
+ ```
+
+ Alternatively, you can download the weights of the separate components through HuggingFace using the following links:
+
+ | Architecture | Input Modality | Pretraining Objective | Model Weights |
+ |---------------------------|--------------------|-------------------------|-----------------------------------------------------------------------------------------------------------------------------|
+ | SPECTRE-ViT-Local | CT crops | SSL | [Link](https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128_no_vla.pt?download=true) |
+ | SPECTRE-ViT-Local | CT crops | SSL + VLA | [Link](https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128.pt?download=true) |
+ | SPECTRE-ViT-Global | Embedded CT crops | VLA | [Link](https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_combiner_feature_vit_large.pt?download=true) |
+ | Qwen3-Embedding-0.6B LoRA | Text (radiology) | VLA | [Link](https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_qwen3_embedding_0.6B_lora.pt?download=true) |
+
+ ## 🩻 Segmentation (nnUNet)
+
+ If you're looking for an nnUNet-based segmentation pipeline that uses SPECTRE as the backbone, see: https://github.com/cviviers/nnUNet
+
+ ## 📂 Repository Contents
+
+ This repository is organized as follows:
+
+ - 🚀 **`src/spectre/`** – Contains the core package, including:
+   - Pretraining methods
+   - Model architectures
+   - Data handling and transformations
+
+ - 🛠️ **`src/spectre/configs/`** – Stores configuration files for different training settings.
+
+ - 🔬 **`experiments/`** – Includes Python scripts for running various pretraining and downstream experiments.
+
+ - 🐳 **`Dockerfile`** – Defines the environment for running a local version of SPECTRE inside a container.
+
+ ## ⚙️ Setting Up the Environment
+
+ To get up and running with SPECTRE, install the base package with pip:
+
+ ```bash
+ pip install spectre-fm
+ ```
+
+ This installs only the runtime dependencies needed to load and run the pretrained models.
+
+ If you want to fine-tune or pretrain SPECTRE, install the matching extra:
+
+ ```bash
+ pip install "spectre-fm[training]"
+ ```
+
+ If you only need the evaluation stack, install:
+
+ ```bash
+ pip install "spectre-fm[eval]"
+ ```
+
+ If training on GDS-enabled systems is required, install the CUDA 12 specific extra:
+
+ ```bash
+ pip install "spectre-fm[gds-cuda12]" # with training stack: "spectre-fm[training,gds-cuda12]"
+ ```
+
+ **Note that** `gds-cuda12` is only compatible with CUDA 12.x environments.
+
+ To install everything at once, use:
+
+ ```bash
+ pip install "spectre-fm[all]"
+ ```
+
+ or install the latest updates directly from GitHub:
+
+ ```bash
+ pip install git+https://github.com/cclaess/SPECTRE.git
+ ```
+
+ ## 🐳 Building and Using Docker
+
+ To facilitate deployment and reproducibility, SPECTRE can be run using **Docker**. This lets you set up a fully functional environment from your own local copy of SPECTRE without installing dependencies manually.
+
+ ### **Building the Docker Image**
+ First, ensure you have **Docker** installed. Then, clone and navigate to the repository to build the image:
+ ```bash
+ git clone https://github.com/cclaess/SPECTRE
+ cd SPECTRE
+ docker build -t spectre-fm .
+ ```
+
+ ### **Running Experiments Inside Docker**
+ Once the image is built, you can start a container and execute scripts inside it. For example, to run a DINO pretraining experiment:
+ ```bash
+ docker run --gpus all --rm -v "$(pwd):/mnt" spectre-fm python3 experiments/pretraining/pretrain_dino.py --config_file spectre/configs/dino_default.yaml --output_dir /mnt/outputs/pretraining/dino/
+ ```
+ - `--gpus all` enables GPU acceleration if available.
+ - `--rm` removes the container after execution.
+ - `-v "$(pwd):/mnt"` mounts the current directory at `/mnt` inside the container.
+
+ ## ⚖️ License
+ - **Code: MIT** – see `LICENSE` (permissive; commercial use permitted).
+ - **Pretrained model weights: CC-BY-NC-SA** – non-commercial share-alike. The weights and any derivative models that include these weights are NOT cleared for commercial use. See `LICENSE_MODELS` for details and the precise license text.
+
+ > Note: the pretrained weights are subject to the original dataset licenses. Users intending to use SPECTRE in commercial settings should verify dataset and model licensing and obtain any required permissions.
+
+ ## 📜 Citation
+ If you use SPECTRE in your research or wish to cite it, please use the following BibTeX entry for our [preprint](https://arxiv.org/abs/2511.17209):
+ ```
+ @misc{claessens_scaling_2025,
+   title = {Scaling {Self}-{Supervised} and {Cross}-{Modal} {Pretraining} for {Volumetric} {CT} {Transformers}},
+   url = {http://arxiv.org/abs/2511.17209},
+   doi = {10.48550/arXiv.2511.17209},
+   author = {Claessens, Cris and Viviers, Christiaan and D'Amicantonio, Giacomo and Bondarev, Egor and Sommen, Fons van der},
+   year={2025},
+ }
+ ```
+
+ ## 🤝 Acknowledgements
+ This project builds upon prior work in self-supervised learning, medical imaging, and transformer-based representation learning. We especially acknowledge [**MONAI**](https://project-monai.github.io/) for their awesome framework and the [**timm**](https://timm.fast.ai/) & [**lightly**](https://docs.lightly.ai/self-supervised-learning/) Python libraries for providing 2D PyTorch models (timm) and object-oriented self-supervised learning methods (lightly), from which we adapted parts of the code for 3D.
+
+ [![Star History Chart](https://api.star-history.com/svg?repos=cclaess/SPECTRE&type=Date)](https://star-history.com/#cclaess/SPECTRE&Date)
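Note on the README usage example: the `grid_size` passed to the model is the scan shape divided by the crop shape along each axis, and the number of crops is the product of those three values. The sketch below only illustrates that arithmetic in plain Python. The helper name `crop_grid` is hypothetical and not part of the `spectre` package, and it assumes non-overlapping crops that tile the scan exactly, as in the README's (384 x 384 x 256) example.

```python
def crop_grid(volume_shape, crop_shape):
    """Return (grid_size, num_crops) for a volume tiled by non-overlapping crops.

    Hypothetical helper for illustration; not part of the spectre package.
    """
    if len(volume_shape) != len(crop_shape):
        raise ValueError("volume_shape and crop_shape must have the same rank")
    for size, crop in zip(volume_shape, crop_shape):
        if size % crop != 0:
            raise ValueError(f"{size} is not divisible by crop size {crop}")
    # One grid entry per spatial axis: how many crops fit along that axis.
    grid_size = tuple(size // crop for size, crop in zip(volume_shape, crop_shape))
    num_crops = 1
    for g in grid_size:
        num_crops *= g
    return grid_size, num_crops


grid_size, num_crops = crop_grid((384, 384, 256), (128, 128, 64))
print(grid_size, num_crops)  # (3, 3, 4) 36
```

Scans whose size is not a multiple of the crop size would need padding or cropping first; how the repository handles that case is not shown in this commit.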
config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "architectures": [
+     "SpectreModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_spectre.SpectreConfig",
+     "AutoModel": "modeling_spectre.SpectreModel"
+   },
+   "backbone_kwargs": {
+     "global_pool": "",
+     "init_values": 1.0,
+     "num_classes": 0,
+     "pos_embed": "rope",
+     "rope_kwargs": {
+       "base": 1000.0
+     }
+   },
+   "backbone_name": "vit_large_patch16_128",
+   "dtype": "float32",
+   "feature_combiner_kwargs": {
+     "global_pool": "",
+     "init_values": 1.0,
+     "num_classes": 0,
+     "pos_embed": "rope",
+     "rope_kwargs": {
+       "base": 100.0
+     }
+   },
+   "feature_combiner_name": "feat_vit_large",
+   "model_type": "spectre",
+   "transformers_version": "5.3.0"
+ }
configuration_spectre.py ADDED
@@ -0,0 +1,32 @@
+ from transformers import PretrainedConfig
+
+
+ class SpectreConfig(PretrainedConfig):
+     model_type = "spectre"
+
+     def __init__(
+         self,
+         backbone_name="vit_large_patch16_128",
+         backbone_kwargs={
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 1000.0},
+             "init_values": 1.0,
+         },
+         feature_combiner_name="feat_vit_large",
+         feature_combiner_kwargs={
+             "num_classes": 0,
+             "global_pool": "",
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 100.0},
+             "init_values": 1.0,
+         },
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.backbone_name = backbone_name
+         self.backbone_kwargs = backbone_kwargs or {}
+         self.feature_combiner_name = feature_combiner_name
+         self.feature_combiner_kwargs = feature_combiner_kwargs or {}
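Note on `SpectreConfig`: the keyword defaults above are dict literals, and `backbone_kwargs or {}` hands that same dict object to every instance created without an explicit argument. Python evaluates default values once, at function definition, so an in-place change on one instance's dict would be visible on the others. Whether this ever matters here depends on whether any caller mutates the dicts, which this commit does not show. The sketch below demonstrates the general behaviour with hypothetical stand-in functions (`make_config`, `make_config_safe`), not with the real class.

```python
import copy


def make_config(backbone_kwargs={"num_classes": 0}):
    # The default dict is created once, when the function is defined,
    # and `or` returns that same object because it is non-empty.
    return {"backbone_kwargs": backbone_kwargs or {}}


a = make_config()
b = make_config()
a["backbone_kwargs"]["num_classes"] = 5
print(b["backbone_kwargs"]["num_classes"])  # 5: both configs share one dict


def make_config_safe(backbone_kwargs=None):
    # One common alternative: default to None and copy the defaults per call.
    defaults = {"num_classes": 0}
    chosen = defaults if backbone_kwargs is None else backbone_kwargs
    return {"backbone_kwargs": copy.deepcopy(chosen)}


c = make_config_safe()
d = make_config_safe()
c["backbone_kwargs"]["num_classes"] = 5
print(d["backbone_kwargs"]["num_classes"])  # 0: each config owns its dict
```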
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c90810e90d833cfff54880997fad4d964963fef4dd6824789f44a5bed063cb61
+ size 1587720248
modeling_spectre.py ADDED
@@ -0,0 +1,38 @@
+ import torch
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput
+
+ from spectre.model import SpectreImageFeatureExtractor
+ from configuration_spectre import SpectreConfig
+
+
+ class SpectreModel(PreTrainedModel):
+     config_class = SpectreConfig
+     base_model_prefix = "spectre"
+     main_input_name = "pixel_values"
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.model = SpectreImageFeatureExtractor(
+             backbone_name=config.backbone_name,
+             backbone_kwargs=config.backbone_kwargs,
+             feature_combiner_name=config.feature_combiner_name,
+             feature_combiner_kwargs=config.feature_combiner_kwargs,
+         )
+
+         self.post_init()
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         grid_size=None,
+         return_dict=True,
+         **kwargs,
+     ):
+         outputs = self.model(pixel_values, grid_size=grid_size)
+
+         if not return_dict:
+             return (outputs,)
+
+         return BaseModelOutput(last_hidden_state=outputs)
spectre/__init__.py ADDED
@@ -0,0 +1,29 @@
+ """Top-level package for spectre.
+
+ Expose a small, stable public API here so users can do:
+
+     from spectre import SpectreImageFeatureExtractor, models
+
+ Keep implementations in subpackages; this file only re-exports the most
+ important symbols and subpackages for convenience.
+ """
+ from .model import SpectreImageFeatureExtractor, MODEL_CONFIGS
+ from . import models
+ from . import utils
+
+ __version__ = "0.1.0"
+ __author__ = "Cris Claessens"
+ __email__ = "c.h.b.claessens@tue.nl"
+
+ __all__ = [
+     "SpectreImageFeatureExtractor",
+     "MODEL_CONFIGS",
+     "models",
+     "data",
+     "transforms",
+     "ssl",
+     "utils",
+     "__version__",
+     "__author__",
+     "__email__",
+ ]
spectre/model.py ADDED
@@ -0,0 +1,234 @@
+ import torch
+ import torch.nn as nn
+
+
+ MODEL_CONFIGS = {
+     "spectre-small": {
+         "name": "spectre-small",
+         "backbone": "vit_small_patch16_128",
+         "backbone_checkpoint_path_or_url": None,
+         "backbone_kwargs": {
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 1000.0},
+             "init_values": 1.0,
+         },
+         "feature_combiner": "feat_vit_small",
+         "feature_combiner_checkpoint_path_or_url": None,
+         "feature_combiner_kwargs": {
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 100.0},
+             "init_values": 1.0,
+         },
+         "description": "SPECTRE model with ViT-Small backbone and feature combiner.",
+     },  # Pretrained/Distilled checkpoints will be added later
+     "spectre-base": {
+         "name": "spectre-base",
+         "backbone": "vit_base_patch16_128",
+         "backbone_checkpoint_path_or_url": None,
+         "backbone_kwargs": {
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 1000.0},
+             "init_values": 1.0,
+         },
+         "feature_combiner": "feat_vit_base",
+         "feature_combiner_checkpoint_path_or_url": None,
+         "feature_combiner_kwargs": {
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 100.0},
+             "init_values": 1.0,
+         },
+         "description": "SPECTRE model with ViT-Base backbone and feature combiner.",
+     },  # Pretrained/Distilled checkpoints will be added later
+     "spectre-large": {
+         "name": "spectre-large",
+         "backbone": "vit_large_patch16_128",
+         "backbone_checkpoint_path_or_url": None,
+         "backbone_kwargs": {
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 1000.0},
+             "init_values": 1.0,
+         },
+         "feature_combiner": "feat_vit_large",
+         "feature_combiner_checkpoint_path_or_url": None,
+         "feature_combiner_kwargs": {
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 100.0},
+             "init_values": 1.0,
+         },
+         "description": "SPECTRE model with ViT-Large backbone and feature combiner.",
+     },
+     "spectre-large-pretrained": {
+         "name": "spectre-large-pretrained",
+         "backbone": "vit_large_patch16_128",
+         "backbone_checkpoint_path_or_url": "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128.pt?download=true",
+         "backbone_kwargs": {
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 1000.0},
+             "init_values": 1.0,
+         },
+         "feature_combiner": "feat_vit_large",
+         "feature_combiner_checkpoint_path_or_url": "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_combiner_feature_vit_large.pt?download=true",
+         "feature_combiner_kwargs": {
+             "num_classes": 0,
+             "global_pool": '',
+             "pos_embed": "rope",
+             "rope_kwargs": {"base": 100.0},
+             "init_values": 1.0,
+         },
+         "description": "Pretrained SPECTRE model with ViT-Large backbone and feature combiner.",
+     }
+ }
+
+
+ class SpectreImageFeatureExtractor(nn.Module):
+     def __init__(
+         self,
+         backbone_name: str,
+         backbone_kwargs: dict = {},
+         backbone_checkpoint_path_or_url: str | None = None,
+         feature_combiner_name: str | None = None,
+         feature_combiner_kwargs: dict = {},
+         feature_combiner_checkpoint_path_or_url: str | None = None,
+         **kwargs,
+     ):
+         super().__init__()
+         self.backbone = None
+         self.feature_combiner = None
+         self._init_backbone(
+             backbone_name,
+             checkpoint_path_or_url=backbone_checkpoint_path_or_url,
+             **backbone_kwargs,
+             **kwargs,
+         )
+         if feature_combiner_name is not None:
+             self._init_feature_combiner(
+                 feature_combiner_name,
+                 checkpoint_path_or_url=feature_combiner_checkpoint_path_or_url,
+                 **feature_combiner_kwargs,
+                 **kwargs,
+             )
+
+     def _init_backbone(
+         self,
+         model_name: str,
+         checkpoint_path_or_url: str | None = None,
+         **kwargs
+     ):
+         backbone_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name)
+         self.backbone = backbone_cls(
+             checkpoint_path_or_url=checkpoint_path_or_url,
+             **kwargs,
+         )
+
+     def _init_feature_combiner(
+         self,
+         model_name: str,
+         checkpoint_path_or_url: str | None = None,
+         **kwargs,
+     ):
+         if self.backbone.global_pool == '':
+             patch_dim = self.backbone.embed_dim * 2  # CLS + AVG pooled tokens
+         else:
+             patch_dim = self.backbone.embed_dim
+
+         feature_combiner_cls = getattr(__import__('spectre.models', fromlist=[model_name]), model_name)
+         self.feature_combiner = feature_combiner_cls(
+             patch_dim=patch_dim,
+             checkpoint_path_or_url=checkpoint_path_or_url,
+             **kwargs,
+         )
+
+     def extract_backbone_features(
+         self,
+         x: torch.Tensor,
+     ):
+         """
+         Extract features from the backbone for a batch of image sets. Input is expected to be of
+         shape (B, N, C, H, W, D), where B is the batch size, N is the number of image patches per
+         image, C is the number of channels, H is height, W is width, and D is depth.
+         The output will be a tensor of extracted features (B, N, T, F) where T is the number of
+         tokens and F is the feature dimension.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (B, N, C, H, W, D)
+         Returns:
+             torch.Tensor: Extracted features of shape (B, N, T, F)
+         """
+         assert x.ndim == 6, "Input tensor must have 6 dimensions: (B, N, C, H, W, D)"
+         B, N, C, H, W, D = x.shape
+         x = x.view(B * N, C, H, W, D)
+         features = self.backbone(x)
+         if features.ndim == 2:  # only CLS token
+             features = features.unsqueeze(1)
+         features = features.view(B, N, features.shape[1], -1)
+         return features
+
+     def combine_features(
+         self,
+         features: torch.Tensor,
+         grid_size: tuple[int, int, int],
+     ):
+         """
+         Combine features from multiple image patches using the feature combiner.
+
+         Args:
+             features (torch.Tensor): Input features of shape (B, N, T, F)
+             grid_size (tuple[int, int, int]): Grid size of the image patches
+         Returns:
+             torch.Tensor: Combined features of shape (B, T', F')
+         """
+         _, N, T, _ = features.shape
+         assert features.ndim == 4, "Input features must have 4 dimensions: (B, N, T, F)"
+         assert N == grid_size[0] * grid_size[1] * grid_size[2], \
+             "Number of patches N must match the product of grid_size dimensions"
+
+         if T == 1:  # only CLS token
+             features = features.squeeze(2)
+         else:
+             # We combine CLS tokens with AVG pooling of other tokens
+             features = torch.cat([
+                 features[:, :, 0, :],  # CLS token (B, N, F)
+                 features[:, :, 1:, :].mean(dim=2)  # AVG pooled tokens (B, N, F)
+             ], dim=-1)  # (B, N, 2F)
+         features = self.feature_combiner(features, grid_size)  # (B, T', F')
+         return features
+
+     def forward(self, x, grid_size: tuple[int, int, int] | None = None):
+         features = self.extract_backbone_features(x)
+         if self.feature_combiner is not None:
+             assert grid_size is not None, \
+                 "`grid_size` must be provided when using feature combiner"
+             features = self.combine_features(features, grid_size)
+         return features
+
+     @classmethod
+     def from_config(
+         cls,
+         config: dict,
+         **kwargs,
+     ) -> 'SpectreImageFeatureExtractor':
+
+         model = cls(
+             backbone_name=config["backbone"],
+             backbone_checkpoint_path_or_url=config.get("backbone_checkpoint_path_or_url", None),
+             backbone_kwargs=config.get("backbone_kwargs", {}),
+             feature_combiner_name=config.get("feature_combiner", None),
+             feature_combiner_checkpoint_path_or_url=config.get("feature_combiner_checkpoint_path_or_url", None),
+             feature_combiner_kwargs=config.get("feature_combiner_kwargs", {}),
+             **kwargs,
+         )
+         return model
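Note on `combine_features`: when the backbone returns more than one token per crop, each crop is summarized as its CLS token concatenated with the mean of the remaining tokens. This is why `_init_feature_combiner` uses `embed_dim * 2` as `patch_dim` when `global_pool == ''`. The sketch below reproduces that pooling for a single crop with plain Python lists instead of tensors. The function name `cls_plus_mean` and the toy values are illustrative assumptions, not part of the package.

```python
def cls_plus_mean(tokens):
    """Summarize one crop: CLS token concatenated with the mean of the other tokens.

    `tokens` is a list of T token vectors (lists of floats) with the CLS token
    first. Hypothetical list-based stand-in for the tensor code in the diff.
    """
    if len(tokens) < 2:
        raise ValueError("expected a CLS token plus at least one patch token")
    cls_token = tokens[0]
    patch_tokens = tokens[1:]
    # Column-wise mean over the patch tokens, one value per feature.
    mean_token = [sum(column) / len(patch_tokens) for column in zip(*patch_tokens)]
    return cls_token + mean_token  # length 2F


tokens = [
    [1.0, 2.0],  # CLS token, F = 2
    [0.0, 4.0],  # patch token 1
    [2.0, 0.0],  # patch token 2
]
print(cls_plus_mean(tokens))  # [1.0, 2.0, 1.0, 2.0]
```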
spectre/models/__init__.py ADDED
@@ -0,0 +1,53 @@
+ from .vision_transformer import (
+     VisionTransformer,
+     vit_tiny_patch16_128,
+     vit_small_patch16_128,
+     vit_base_patch16_128,
+     vit_base_patch32_128,
+     vit_large_patch16_128,
+     vit_large_patch32_128,
+ )
+ from .vision_transformer_features import (
+     FeatureVisionTransformer,
+     feat_vit_tiny,
+     feat_vit_small,
+     feat_vit_base,
+     feat_vit_large,
+ )
+ from .resnet import (
+     ResNet,
+     resnet18,
+     resnet34,
+     resnet50,
+     resnet101,
+     resnext50,
+     resnext101,
+ )
+ from .eomt import EoMT
+ from .seomt import SEoMT
+ from .upsample_anything import UPA
+
+ __all__ = [
+     'VisionTransformer',
+     'vit_tiny_patch16_128',
+     'vit_small_patch16_128',
+     'vit_base_patch16_128',
+     'vit_base_patch32_128',
+     'vit_large_patch16_128',
+     'vit_large_patch32_128',
+     'FeatureVisionTransformer',
+     'feat_vit_tiny',
+     'feat_vit_small',
+     'feat_vit_base',
+     'feat_vit_large',
+     'ResNet',
+     'resnet18',
+     'resnet34',
+     'resnet50',
+     'resnet101',
+     'resnext50',
+     'resnext101',
+     'EoMT',
+     'UPA',
+     'SEoMT',
+ ]
spectre/models/eomt.py ADDED
@@ -0,0 +1,230 @@
+ # Adapted from https://github.com/tue-mps/eomt/
+ from __future__ import annotations
+
+ import math
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from spectre.models.layers import LayerNorm3d
+
+
+ class ScaleBlock(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int,
+         scale_factors: Union[int, Tuple[int, int, int]] = (2, 2, 2),
+         conv1_layer: nn.Module = nn.ConvTranspose3d,
+     ):
+         super().__init__()
+
+         self.conv1 = conv1_layer(
+             embed_dim,
+             embed_dim,
+             kernel_size=scale_factors,
+             stride=scale_factors,
+         )
+         self.act = nn.GELU()
+         self.conv2 = nn.Conv3d(
+             embed_dim,
+             embed_dim,
+             kernel_size=3,
+             padding=1,
+             groups=embed_dim,
+             bias=False,
+         )
+         self.norm = LayerNorm3d(embed_dim)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.act(x)
+         x = self.conv2(x)
+         x = self.norm(x)
+
+         return x
+
+
+ def compute_upscale_stages(patch_size, min_size=4):
+     # Compute how many times to upscale per dimension
+     num_stages = []
+     for size in patch_size:
+         stages = max(0, int(math.log2(size)) - int(math.log2(min_size)))
+         num_stages.append(stages)
+     return num_stages
+
+
+ class EoMT(nn.Module):
+     def __init__(
+         self,
+         backbone: "VisionTransformer",
+         num_classes: int,
+         num_q: int,
+         num_blocks=4,
+         masked_attn_enabled=True,
+     ):
+         super().__init__()
+         self.backbone = backbone
+         self.num_q = num_q
+         self.num_blocks = num_blocks
+         self.masked_attn_enabled = masked_attn_enabled
+
+         self.register_buffer("attn_mask_probs", torch.ones(num_blocks))
+
+         self.q = nn.Embedding(num_q, self.backbone.embed_dim)
+
+         self.class_head = nn.Linear(self.backbone.embed_dim, num_classes + 1)
+
+         self.mask_head = nn.Sequential(
+             nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
+             nn.GELU(),
+             nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
+             nn.GELU(),
+             nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
+         )
+
+         patch_size = self.backbone.patch_embed.patch_size
+         num_upscale_stages = compute_upscale_stages(patch_size, min_size=4)
+
+         # Build per-stage scale factors list
+         max_stages = max(num_upscale_stages)
+         upscale_blocks = []
+         for stage_idx in range(max_stages):
+             # for each dimension, upscale by 2 only if this dimension still has
+             # remaining upscales at this stage
+             scale_factors = tuple(
+                 2 if stage_idx < num_upscale_stages[dim] else 1
+                 for dim in range(len(patch_size))
+             )
+             upscale_blocks.append(ScaleBlock(self.backbone.embed_dim,
+                                              scale_factors=scale_factors))
+
+         self.upscale = nn.Sequential(*upscale_blocks)
+
+     def _predict(self, x: torch.Tensor):
+         q = x[:, : self.num_q, :]
+
+         class_logits = self.class_head(q)
+
+         x = x[:, self.num_q + self.backbone.num_prefix_tokens :, :]
+         x = x.transpose(1, 2).reshape(
+             x.shape[0], -1, *self.backbone.patch_embed.grid_size
+         )
+
+         mask_logits = torch.einsum(
+             "bqc, bchwd -> bqhwd", self.mask_head(q), self.upscale(x)
+         )
+
+         return mask_logits, class_logits
+
+     @torch.compiler.disable
+     def _disable_attn_mask(self, attn_mask, prob):
+         if prob < 1:
+             random_queries = (
+                 torch.rand(attn_mask.shape[0], self.num_q, device=attn_mask.device)
+                 > prob
+             )
+             attn_mask[
+                 :, : self.num_q, self.num_q + self.backbone.num_prefix_tokens :
+             ][random_queries] = True
+
+         return attn_mask
+
+     def _attn(self, module: 'Attention', x: torch.Tensor, mask: Optional[torch.Tensor], rope=None):
+         B, N, C = x.shape
+
+         q = module.q(x).reshape(B, N, module.num_heads, module.head_dim).permute(0, 2, 1, 3)
+         kv = module.kv(x).reshape(B, N, 2, module.num_heads, module.head_dim)
+         k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
+         q, k = module.q_norm(q), module.k_norm(k)
+
+         if mask is not None:
+             mask = mask[:, None, ...].expand(-1, module.num_heads, -1, -1)
+
+         dropout_p = module.attn_drop.p if self.training else 0.0
+
+         if rope is not None:
+             if isinstance(rope, list):
+                 rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
+             q, k = module.apply_rotary_pos_emb(q, k, rope)
+
+         if module.fused_attn:
+             x = F.scaled_dot_product_attention(q, k, v, mask, dropout_p)
+         else:
+             attn = (q @ k.transpose(-2, -1)) * module.scale
+             if mask is not None:
+                 attn = attn.masked_fill(~mask, float("-inf"))
+             attn = F.softmax(attn, dim=-1)
+             attn = module.attn_drop(attn)
+             x = attn @ v
+
+         x = module.proj_drop(module.proj(x.transpose(1, 2).reshape(B, N, C)))
+
+         return x
+
+     def forward(self, x: torch.Tensor):
+         x = self.backbone.patch_embed(x)
+         x, rope = self.backbone._pos_embed(x)
+         x = self.backbone.patch_drop(x)
+         x = self.backbone.norm_pre(x)
+
+         attn_mask = None
+         mask_logits_per_layer, class_logits_per_layer = [], []
+
+         for i, block in enumerate(self.backbone.blocks):
+             if i == len(self.backbone.blocks) - self.num_blocks:
+                 x = torch.cat(
+                     (self.q.weight[None, :, :].expand(x.shape[0], -1, -1), x), dim=1
+                 )
+
+             if (
+                 self.masked_attn_enabled
+                 and i >= len(self.backbone.blocks) - self.num_blocks
+             ):
+                 mask_logits, class_logits = self._predict(self.backbone.norm(x))
+                 mask_logits_per_layer.append(mask_logits)
+                 class_logits_per_layer.append(class_logits)
+
+                 attn_mask = torch.ones(
+                     x.shape[0],
+                     x.shape[1],
+                     x.shape[1],
+                     dtype=torch.bool,
+                     device=x.device,
+                 )
+                 interpolated = F.interpolate(
+                     mask_logits,
+                     self.backbone.patch_embed.grid_size,
+                     mode="trilinear",
+                 )
+                 interpolated = interpolated.view(
+                     interpolated.size(0), interpolated.size(1), -1
+                 )
+                 attn_mask[
+                     :,
+                     : self.num_q,
+                     self.num_q + self.backbone.num_prefix_tokens :,
+                 ] = (
+                     interpolated > 0
+                 )
+                 attn_mask = self._disable_attn_mask(
+                     attn_mask,
+                     self.attn_mask_probs[
+                         i - len(self.backbone.blocks) + self.num_blocks
+                     ],
+                 )
+
+             x = x + block.drop_path1(
+                 block.ls1(self._attn(block.attn, block.norm1(x), attn_mask, rope))
+             )
+             x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))
+
+         mask_logits, class_logits = self._predict(self.backbone.norm(x))
+         mask_logits_per_layer.append(mask_logits)
+         class_logits_per_layer.append(class_logits)
+
+         return (
+             mask_logits_per_layer,
+             class_logits_per_layer,
+         )
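Note on the upscaling schedule in `EoMT.__init__`: `compute_upscale_stages` counts how many 2x upscales each axis needs, and the constructor turns those counts into one `ScaleBlock` per stage, using a factor of 1 for axes that have run out of stages. The sketch below replays that logic without PyTorch so the resulting factors can be inspected. The patch size `(16, 16, 8)` is a hypothetical anisotropic example, not a value confirmed by this commit, and `stage_scale_factors` is an illustrative helper name.

```python
import math


def compute_upscale_stages(patch_size, min_size=4):
    # Same arithmetic as the helper in the diff: 2x upscales per dimension.
    return [max(0, int(math.log2(size)) - int(math.log2(min_size))) for size in patch_size]


def stage_scale_factors(patch_size, min_size=4):
    """List the per-stage scale factors the EoMT constructor would build."""
    stages = compute_upscale_stages(patch_size, min_size)
    return [
        # An axis is upscaled by 2 only while it still has stages left.
        tuple(2 if stage_idx < stages[dim] else 1 for dim in range(len(patch_size)))
        for stage_idx in range(max(stages))
    ]


print(compute_upscale_stages((16, 16, 8)))  # [2, 2, 1]
print(stage_scale_factors((16, 16, 8)))     # [(2, 2, 2), (2, 2, 1)]
```

For this example the combined factor is (4, 4, 2), so each upscaled feature cell covers `min_size = 4` voxels along every axis.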
spectre/models/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .patch_embed import PatchEmbed
+ from .attention import Attention
+ from .layernorm import LayerNorm3d
+ from .rotary_pos_embed import RotaryPositionEmbedding
+
+ __all__ = [
+     'PatchEmbed',
+     'Attention',
+     'LayerNorm3d',
+     'RotaryPositionEmbedding',
+ ]
spectre/models/layers/attention.py ADDED
@@ -0,0 +1,187 @@
1
+ from typing import Type, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.jit import Final
7
+ from timm.layers import use_fused_attn
8
+
9
+ from spectre.models.layers.rotary_pos_embed import rope_apply
10
+
11
+
12
+ class Attention(nn.Module):
13
+ fused_attn: Final[bool]
14
+
15
+ def __init__(
16
+ self,
17
+ dim: int,
18
+ num_heads: int = 8,
19
+ mode: str = "mha",
20
+ q_proj_dim: Optional[int] = None,
21
+ kv_proj_dim: Optional[int] = None,
22
+ qkv_bias: bool = False,
23
+ qk_norm: bool = False,
24
+ proj_bias: bool = True,
25
+ attn_drop: float = 0.,
26
+ proj_drop: float = 0.,
27
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
28
+ ) -> None:
29
+ super().__init__()
30
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
31
+ self.num_heads = num_heads
32
+ self.head_dim = dim // num_heads
33
+ self.scale = self.head_dim ** -0.5
34
+ self.fused_attn = use_fused_attn()
35
+ self.mode = mode.lower()
36
+ assert self.mode in ["mha", "mqa", "mla"], "Attention mode must be 'mha', 'mqa', or 'mla'"
37
+ assert not (self.mode == "mla" and kv_proj_dim is None), "kv_proj_dim must be provided for 'mla' mode"
38
+ assert not (self.mode == "mla" and q_proj_dim is None), "q_proj_dim must be provided for 'mla' mode"
39
+
40
+ if self.mode == "mha":
41
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
42
+ self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias) # Key and value pair for every head
43
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
44
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
45
+ elif self.mode == "mqa":
46
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
47
+ self.kv = nn.Linear(dim, 2 * self.head_dim, bias=qkv_bias) # Key and value pair shared across heads
48
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
49
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
50
+ elif self.mode == "mla":
51
+ self.q_proj = nn.Linear(dim, q_proj_dim, bias=qkv_bias) # Projected query for every head
52
+ self.kv_proj = nn.Linear(dim, kv_proj_dim, bias=qkv_bias) # Projected key and value pair for every head
53
+ self.q_norm = norm_layer(q_proj_dim) if qk_norm else nn.Identity()
54
+ self.kv_norm = norm_layer(kv_proj_dim) if qk_norm else nn.Identity()
55
+ self.q = nn.Linear(q_proj_dim, dim, bias=qkv_bias) # Query for every head
56
+ self.kv = nn.Linear(kv_proj_dim, 2 * dim, bias=qkv_bias) # Key and value pair for every head
57
+
58
+ self.attn_drop = nn.Dropout(attn_drop)
59
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
60
+ self.proj_drop = nn.Dropout(proj_drop)
61
+
62
+ def apply_rotary_pos_emb(
63
+ self,
64
+ q: torch.Tensor,
65
+ k: torch.Tensor,
66
+ rope: Tuple[torch.Tensor, torch.Tensor],
67
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
68
+ """
69
+ Apply RoPE to the query and key tensors.
70
+ Args:
71
+ q (torch.Tensor): Query tensor of shape (B, num_heads, N, head_dim)
72
+ k (torch.Tensor): Key tensor of shape (B, num_heads, N, head_dim)
73
+ rope (Tuple[torch.Tensor, torch.Tensor]): Tuple of (sin, cos) tensors for RoPE application.
74
+ Sin and cos can be of shape [N, head_dim] or [B, N, head_dim].
75
+ """
76
+
77
+ # Match dtype to rope for numeric stability
78
+ q_dtype, k_dtype = q.dtype, k.dtype
79
+ sin, cos = rope
80
+
81
+ if sin.ndim == 2:
82
+ sin = sin.unsqueeze(0).unsqueeze(0)
83
+ cos = cos.unsqueeze(0).unsqueeze(0)
84
+ elif sin.ndim == 3:
85
+ sin = sin.unsqueeze(1)
86
+ cos = cos.unsqueeze(1)
87
+ else:
88
+ raise ValueError("RoPE sin/cos must be of shape [N, head_dim] or [B, N, head_dim]")
89
+
90
+ rope_dtype = sin.dtype
91
+
92
+ q = q.to(dtype=rope_dtype)
93
+ k = k.to(dtype=rope_dtype)
94
+
95
+ N = q.shape[-2] # total tokens per sample
96
+ N_spatial = sin.shape[-2] # number of spatial tokens covered by rope
97
+ prefix = N - N_spatial # e.g., [cls] or [reg] tokens at the front
98
+ assert prefix >= 0, "RoPE sin/cos length exceeds sequence length"
99
+
100
+ if prefix > 0:
101
+ q_prefix = q[:, :, :prefix, :]
102
+ k_prefix = k[:, :, :prefix, :]
103
+ q_spatial = q[:, :, prefix:, :]
104
+ k_spatial = k[:, :, prefix:, :]
105
+ else:
106
+ q_prefix = k_prefix = None
107
+ q_spatial, k_spatial = q, k
108
+
109
+ # Apply RoPE on the spatial tail
110
+ q_spatial = rope_apply(q_spatial, sin, cos)
111
+ k_spatial = rope_apply(k_spatial, sin, cos)
112
+
113
+ # Stitch back
114
+ if prefix > 0:
115
+ q = torch.cat((q_prefix, q_spatial), dim=-2)
116
+ k = torch.cat((k_prefix, k_spatial), dim=-2)
117
+ else:
118
+ q, k = q_spatial, k_spatial
119
+
120
+ # Cast back to original dtypes
121
+ q = q.to(dtype=q_dtype)
122
+ k = k.to(dtype=k_dtype)
123
+
124
+ return q, k
125
+
126
+ def compute_qkv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
127
+ B, N, _ = x.shape
128
+ if self.mode == "mha":
129
+ q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
130
+ kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
131
+ k, v = kv.unbind(0)
132
+ q, k = self.q_norm(q), self.k_norm(k)
133
+ elif self.mode == "mqa":
134
+ q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
135
+ kv = self.kv(x).reshape(B, N, 2, 1, self.head_dim).permute(2, 0, 3, 1, 4)
136
+ kv = kv.expand(-1, -1, self.num_heads, -1, -1) # Expand to match num_heads
137
+ k, v = kv.unbind(0)
138
+ q, k = self.q_norm(q), self.k_norm(k)
139
+ elif self.mode == "mla":
140
+ q = self.q_proj(x)
141
+ kv = self.kv_proj(x)
142
+ q, kv = self.q_norm(q), self.kv_norm(kv) # Normalization on projections
143
+ q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
144
+ kv = self.kv(kv).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
145
+ k, v = kv.unbind(0)
146
+ return q, k, v
147
+
148
+ def compute_attention(
149
+ self,
150
+ q: torch.Tensor,
151
+ k: torch.Tensor,
152
+ v: torch.Tensor,
153
+ ) -> torch.Tensor:
154
+ B, _, N, _ = q.shape
155
+ C = self.num_heads * self.head_dim
156
+
157
+ if self.fused_attn:
158
+ x = F.scaled_dot_product_attention(
159
+ q, k, v,
160
+ dropout_p=self.attn_drop.p if self.training else 0.,
161
+ )
162
+ else:
163
+ q = q * self.scale
164
+ attn = q @ k.transpose(-2, -1)
165
+ attn = attn.softmax(dim=-1)
166
+ attn = self.attn_drop(attn)
167
+ x = attn @ v
168
+
169
+ return x.transpose(1, 2).reshape(B, N, C)
170
+
171
+ def forward(
172
+ self,
173
+ x: torch.Tensor,
174
+ rope = None,
175
+ ) -> torch.Tensor:
176
+
177
+ q, k, v = self.compute_qkv(x)
178
+
179
+ if rope is not None:
180
+ if isinstance(rope, list):
181
+ rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
182
+ q, k = self.apply_rotary_pos_emb(q, k, rope)
183
+
184
+ x = self.compute_attention(q, k, v)
185
+ x = self.proj(x)
186
+ x = self.proj_drop(x)
187
+ return x
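The "mqa" branch and the two attention paths can be checked on toy tensors. The following is a minimal sketch, assuming PyTorch 2.x (for `F.scaled_dot_product_attention`) and made-up sizes; the standalone `q_lin` and `kv_lin` layers are illustrative stand-ins for `self.q` and `self.kv`, not the module itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, N, num_heads, head_dim = 2, 5, 4, 8     # illustrative sizes
dim = num_heads * head_dim

x = torch.randn(B, N, dim)
q_lin = nn.Linear(dim, dim, bias=False)
kv_lin = nn.Linear(dim, 2 * head_dim, bias=False)  # one K/V pair shared by all heads

with torch.no_grad():
    # (B, num_heads, N, head_dim)
    q = q_lin(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
    # (2, B, 1, N, head_dim), then broadcast the single head to num_heads
    kv = kv_lin(x).reshape(B, N, 2, 1, head_dim).permute(2, 0, 3, 1, 4)
    kv = kv.expand(-1, -1, num_heads, -1, -1)
    k, v = kv.unbind(0)

    # Fused path (default scale is head_dim ** -0.5)
    fused = F.scaled_dot_product_attention(q, k, v)

    # Manual path, mirroring the non-fused branch
    attn = (q * head_dim ** -0.5) @ k.transpose(-2, -1)
    manual = attn.softmax(dim=-1) @ v

    out = fused.transpose(1, 2).reshape(B, N, dim)  # back to (B, N, C)
```

With dropout disabled, both paths should agree up to floating-point tolerance, and every head sees the same keys and values because `expand` only creates a broadcast view.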
spectre/models/layers/layernorm.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from timm.layers.fast_norm import is_fast_norm, fast_layer_norm
5
+
6
+
7
+ class LayerNorm3d(nn.LayerNorm):
8
+ """ LayerNorm for channels of '3D' spatial NCHWD tensors """
9
+ _fast_norm: torch.jit.Final[bool]
10
+
11
+ def __init__(self, num_channels, eps=1e-6, affine=True):
12
+ super().__init__(num_channels, eps=eps, elementwise_affine=affine)
13
+ self._fast_norm = is_fast_norm() # is_fast_norm() is imported from timm.layers.fast_norm
14
+
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ x = x.permute(0, 2, 3, 4, 1) # NCHWD -> NHWDC (channels last)
17
+ if self._fast_norm:
18
+ x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
19
+ else:
20
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
21
+ x = x.permute(0, 4, 1, 2, 3) # Permute back to NCHWD format
22
+ return x
23
+
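The permute, normalize, permute-back pattern can be reproduced with plain PyTorch. This is a minimal sketch with assumed shapes and identity affine parameters; it uses only `F.layer_norm` and leaves out the timm fast-norm path.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_channels = 6
x = torch.randn(2, num_channels, 4, 4, 3)   # (N, C, H, W, D), illustrative shape

weight = torch.ones(num_channels)           # identity affine for the demo
bias = torch.zeros(num_channels)

y = x.permute(0, 2, 3, 4, 1)                # channels last: (N, H, W, D, C)
y = F.layer_norm(y, (num_channels,), weight, bias, 1e-6)
y = y.permute(0, 4, 1, 2, 3)                # back to (N, C, H, W, D)

# Each voxel is normalized across its channel vector
channel_mean = y.mean(dim=1)
```

The normalization statistics are computed per voxel over the channel axis, so the spatial layout does not affect the result.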
spectre/models/layers/patch_embed.py ADDED
@@ -0,0 +1,135 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union, Callable
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from spectre.utils import (
9
+ to_3tuple,
10
+ resample_patch_embed,
11
+ Format,
12
+ nchwd_to,
13
+ )
14
+
15
+
16
+ class PatchEmbed(nn.Module):
17
+ """ 3D Image to Patch Embedding
18
+ """
19
+ output_fmt: Format
20
+ dynamic_img_pad: torch.jit.Final[bool]
21
+
22
+ def __init__(
23
+ self,
24
+ img_size: Optional[Union[int, Tuple[int, int, int]]] = (128, 128, 64),
25
+ patch_size: Union[int, Tuple[int, int, int]] = (16, 16, 8),
26
+ in_chans: int = 1,
27
+ embed_dim: int = 768,
28
+ norm_layer: Optional[Callable] = None,
29
+ flatten: bool = True,
30
+ output_fmt: Optional[str] = None,
31
+ bias: bool = True,
32
+ strict_img_size: bool = True,
33
+ dynamic_img_pad: bool = False,
34
+ ):
35
+ super().__init__()
36
+ self.patch_size = to_3tuple(patch_size)
37
+ self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
38
+
39
+ if output_fmt is not None:
40
+ self.flatten = False
41
+ self.output_fmt = Format(output_fmt)
42
+ else:
43
+ # flatten spatial dim and transpose to channels last, kept for bwd compat
44
+ self.flatten = flatten
45
+ self.output_fmt = Format.NCHWD
46
+ self.strict_img_size = strict_img_size
47
+ self.dynamic_img_pad = dynamic_img_pad
48
+
49
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
50
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
51
+
52
+ def _init_img_size(self, img_size: Union[int, Tuple[int, int, int]]):
53
+ assert self.patch_size
54
+ if img_size is None:
55
+ return None, None, None
56
+ img_size = to_3tuple(img_size)
57
+ grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
58
+ num_patches = grid_size[0] * grid_size[1] * grid_size[2]
59
+ return img_size, grid_size, num_patches
60
+
61
+ def set_input_size(
62
+ self,
63
+ img_size: Optional[Union[int, Tuple[int, int, int]]] = None,
64
+ patch_size: Optional[Union[int, Tuple[int, int, int]]] = None,
65
+ ):
66
+ new_patch_size = None
67
+ if patch_size is not None:
68
+ new_patch_size = to_3tuple(patch_size)
69
+ if new_patch_size is not None and new_patch_size != self.patch_size:
70
+ with torch.no_grad():
71
+ new_proj = nn.Conv3d(
72
+ self.proj.in_channels,
73
+ self.proj.out_channels,
74
+ kernel_size=new_patch_size,
75
+ stride=new_patch_size,
76
+ bias=self.proj.bias is not None,
77
+ )
78
+ new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
79
+ if self.proj.bias is not None:
80
+ new_proj.bias.copy_(self.proj.bias)
81
+ self.proj = new_proj
82
+ self.patch_size = new_patch_size
83
+ img_size = img_size or self.img_size
84
+ if img_size != self.img_size or new_patch_size is not None:
85
+ self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
86
+
87
+ def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int, int], int]:
88
+ if as_scalar:
89
+ return max(self.patch_size)
90
+ else:
91
+ return self.patch_size
92
+
93
+ def dynamic_feat_size(self, img_size: Tuple[int, int, int]) -> Tuple[int, int, int]:
94
+ """ Get grid (feature) size for given image size taking account of dynamic padding.
95
+ NOTE: must be torchscript compatible so using fixed tuple indexing
96
+ """
97
+ if self.dynamic_img_pad:
98
+ return (
99
+ math.ceil(img_size[0] / self.patch_size[0]),
100
+ math.ceil(img_size[1] / self.patch_size[1]),
101
+ math.ceil(img_size[2] / self.patch_size[2]),
102
+ )
103
+ else:
104
+ return (
105
+ img_size[0] // self.patch_size[0],
106
+ img_size[1] // self.patch_size[1],
107
+ img_size[2] // self.patch_size[2],
108
+ )
109
+
110
+ def forward(self, x):
111
+ _, _, H, W, D = x.shape
112
+ if self.img_size is not None:
113
+ if self.strict_img_size:
114
+ assert H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]})."
115
+ assert W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]})."
116
+ assert D == self.img_size[2], f"Input depth ({D}) doesn't match model ({self.img_size[2]})."
117
+ elif not self.dynamic_img_pad:
118
+ assert H % self.patch_size[0] == 0, \
119
+ f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
120
+ assert W % self.patch_size[1] == 0, \
121
+ f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
122
+ assert D % self.patch_size[2] == 0, \
123
+ f"Input depth ({D}) should be divisible by patch size ({self.patch_size[2]})."
124
+ if self.dynamic_img_pad:
125
+ pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
126
+ pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
127
+ pad_d = (self.patch_size[2] - D % self.patch_size[2]) % self.patch_size[2]
128
+ x = F.pad(x, (0, pad_d, 0, pad_w, 0, pad_h)) # F.pad takes pairs from the last dim backwards: D, then W, then H
129
+ x = self.proj(x)
130
+ if self.flatten:
131
+ x = x.flatten(2).transpose(1, 2) # NCHWD -> NLC
132
+ elif self.output_fmt != Format.NCHWD:
133
+ x = nchwd_to(x, self.output_fmt)
134
+ x = self.norm(x)
135
+ return x
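For a 5D (N, C, H, W, D) input, `F.pad` reads its padding tuple starting from the last dimension, so the pairs come in the order D, W, H. The sketch below pads a volume up to the next multiple of the patch size and embeds it; the input shape, patch size, and `embed_dim` are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

patch_size = (16, 16, 8)                    # (H, W, D) patch, as in the defaults above
embed_dim = 4                               # tiny embedding for the demo
x = torch.zeros(1, 1, 30, 20, 10)           # (N, C, H, W, D), not divisible by the patch

_, _, H, W, D = x.shape
pad_h = (patch_size[0] - H % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - W % patch_size[1]) % patch_size[1]
pad_d = (patch_size[2] - D % patch_size[2]) % patch_size[2]

# Pairs are (left, right) for the last dim first: D, then W, then H
padded = F.pad(x, (0, pad_d, 0, pad_w, 0, pad_h))

proj = nn.Conv3d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
tokens = proj(padded).flatten(2).transpose(1, 2)   # NCHWD -> NLC
```

After padding, each spatial size is a multiple of its patch size, so the strided convolution tiles the volume without dropping voxels.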
spectre/models/layers/rotary_pos_embed.py ADDED
@@ -0,0 +1,157 @@
1
+ import math
2
+ from typing import Literal, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+
9
+ def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
10
+ # x: [..., D], split into halves and rotate
11
+ x1, x2 = x.chunk(2, dim=-1)
12
+ return torch.cat([-x2, x1], dim=-1)
13
+
14
+
15
+ def rope_apply(
16
+ x: torch.Tensor,
17
+ sin: torch.Tensor,
18
+ cos: torch.Tensor
19
+ ) -> torch.Tensor:
20
+ # x, sin, cos: [..., D]
21
+ return (x * cos) + (rope_rotate_half(x) * sin)
22
+
23
+
24
+ class RotaryPositionEmbedding(nn.Module):
25
+ """
26
+ 3D Rotary Positional Embedding (RoPE) with no mixing across axes (axial),
27
+ and no learnable weights. Allows for shifting and scaling of the positional encodings
28
+ for improving performance on varying resolutions.
29
+ Mirrors DINOv3 style but for (H, W, D).
30
+
31
+ Requirements:
32
+ - head_dim % 6 == 0 (because 3 axes -> periods of size head_dim//6, then we tile to fill head_dim)
33
+
34
+ Two parametrizations:
35
+ * base
36
+ * min_period + max_period
37
+ """
38
+ def __init__(
39
+ self,
40
+ embed_dim: int,
41
+ *,
42
+ num_heads: int,
43
+ base: float | None = 1000.0, # works for common 8^3=512 to 16^3=4096 tokens
44
+ min_period: float | None = None,
45
+ max_period: float | None = None,
46
+ normalize_coords: Literal["min", "max", "separate"] = "separate",
47
+ shift_coords: float | None = None,
48
+ jitter_coords: float | None = None,
49
+ rescale_coords: float | None = None,
50
+ dtype: torch.dtype | None = None,
51
+ device: torch.device | None = None,
52
+ ):
53
+ super().__init__()
54
+ assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
55
+ head_dim = embed_dim // num_heads
56
+ assert head_dim % 6 == 0, "For 3D RoPE, (embed_dim // num_heads) must be divisible by 6"
57
+
58
+ both_periods = (min_period is not None) and (max_period is not None)
59
+ if (base is None and not both_periods) or (base is not None and both_periods):
60
+ raise ValueError("Provide exactly one of `base` or `min_period`+`max_period`.")
61
+
62
+ self.base = base
63
+ self.min_period = min_period
64
+ self.max_period = max_period
65
+ self.head_dim = head_dim
66
+ self.normalize_coords = normalize_coords
67
+ self.shift_coords = shift_coords
68
+ self.jitter_coords = jitter_coords
69
+ self.rescale_coords = rescale_coords
70
+
71
+ # Keep dtype persistent so teacher can be initialized from student state_dict()
72
+ self.dtype = dtype
73
+ self.register_buffer(
74
+ "periods",
75
+ torch.empty(self.head_dim // 6, device=device, dtype=dtype),
76
+ persistent=True,
77
+ )
78
+ self._init_weights()
79
+
80
+ @torch.no_grad()
81
+ def _init_weights(self):
82
+ device = self.periods.device
83
+ dtype = self.dtype
84
+ if self.base is not None:
85
+ # exponents 2 * i / (head_dim // 3) for i in 0..(head_dim // 6 - 1), so they lie in [0, 1)
86
+ # for 3D we use // 6 per axis
87
+ periods = self.base ** (
88
+ 2 * torch.arange(self.head_dim // 6, device=device, dtype=dtype) / (self.head_dim // 3)
89
+ )
90
+ else:
91
+ # geometric spacing between min_period and max_period
92
+ base = self.max_period / self.min_period
93
+ exponents = torch.linspace(0, 1, self.head_dim // 6, device=device, dtype=dtype)
94
+ periods = base ** exponents
95
+ periods = periods / base
96
+ periods = periods * self.max_period
97
+ self.periods.data = periods
98
+
99
+ def forward(self, *, H: int, W: int, D: int) -> Tuple[torch.Tensor, torch.Tensor]:
100
+ """
101
+ Returns:
102
+ sin, cos: [H * W * D, head_dim] (per-head)
103
+ """
104
+ device = self.periods.device
105
+ dtype = self.dtype
106
+ dd = dict(device=device, dtype=dtype)
107
+
108
+ # Prepare coords in [0, 1] then map to [-1, +1]
109
+ if self.normalize_coords == "max":
110
+ max_dim = max(H, W, D)
111
+ coords_h = torch.arange(0.5, H, **dd) / max_dim
112
+ coords_w = torch.arange(0.5, W, **dd) / max_dim
113
+ coords_d = torch.arange(0.5, D, **dd) / max_dim
114
+ elif self.normalize_coords == "min":
115
+ min_dim = min(H, W, D)
116
+ coords_h = torch.arange(0.5, H, **dd) / min_dim
117
+ coords_w = torch.arange(0.5, W, **dd) / min_dim
118
+ coords_d = torch.arange(0.5, D, **dd) / min_dim
119
+ elif self.normalize_coords == "separate":
120
+ coords_h = torch.arange(0.5, H, **dd) / H
121
+ coords_w = torch.arange(0.5, W, **dd) / W
122
+ coords_d = torch.arange(0.5, D, **dd) / D
123
+ else:
124
+ raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
125
+
126
+ coords = torch.stack(
127
+ torch.meshgrid(coords_h, coords_w, coords_d, indexing="ij"),
128
+ dim=-1
129
+ ) # [H, W, D, 3]
130
+ coords = coords.flatten(0, 2) # [HWD, 3]
131
+ coords = 2.0 * coords - 1.0 # [-1, +1]
132
+
133
+ # Optional train-time augmentations on coords (DINOv3)
134
+ if self.training and self.shift_coords is not None:
135
+ shift_hwd = torch.empty(3, **dd).uniform_(-self.shift_coords, self.shift_coords)
136
+ coords = coords + shift_hwd[None, :]
137
+
138
+ if self.training and self.jitter_coords is not None:
139
+ jit_max = np.log(self.jitter_coords); jit_min = -jit_max
140
+ jitter = torch.empty(3, **dd).uniform_(jit_min, jit_max).exp()
141
+ coords = coords * jitter[None, :]
142
+
143
+ if self.training and self.rescale_coords is not None:
144
+ r_max = np.log(self.rescale_coords); r_min = -r_max
145
+ rescale = torch.empty(1, **dd).uniform_(r_min, r_max).exp()
146
+ coords = coords * rescale
147
+
148
+ # --- Build angles per axis, then concatenate across axes ---
149
+ # coords: [N, 3] ; periods: [head_dim // 6]
150
+ # angles: [N, 3, head_dim // 6] -> flatten(1, 2) -> [N, head_dim // 2] -> tile(2) -> [N, head_dim]
151
+ angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] # [N, 3, head_dim // 6]
152
+ angles = angles.flatten(1, 2) # [N, head_dim // 2]
153
+ angles = angles.tile(2) # [N, head_dim]
154
+
155
+ cos = torch.cos(angles)
156
+ sin = torch.sin(angles)
157
+ return sin, cos
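The angle construction and rotation can be exercised end to end on a tiny grid. This is a minimal sketch, assuming `head_dim = 12`, a 2 x 3 x 4 grid, the "separate" normalization, and no train-time augmentation; the two helper functions are re-declared so the snippet runs on its own.

```python
import math
import torch

def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def rope_apply(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    return (x * cos) + (rope_rotate_half(x) * sin)

torch.manual_seed(0)
head_dim, base = 12, 1000.0                 # head_dim must be divisible by 6
H, W, D = 2, 3, 4                           # illustrative patch grid

# One set of head_dim // 6 periods, reused for every axis
periods = base ** (2 * torch.arange(head_dim // 6, dtype=torch.float32) / (head_dim // 3))

# Patch-center coordinates, normalized per axis, then mapped to [-1, +1]
coords_h = torch.arange(0.5, H) / H
coords_w = torch.arange(0.5, W) / W
coords_d = torch.arange(0.5, D) / D
coords = torch.stack(torch.meshgrid(coords_h, coords_w, coords_d, indexing="ij"), dim=-1)
coords = 2.0 * coords.flatten(0, 2) - 1.0   # [HWD, 3]

angles = 2 * math.pi * coords[:, :, None] / periods[None, None, :]  # [N, 3, head_dim // 6]
angles = angles.flatten(1, 2).tile(2)       # [N, head_dim]
sin, cos = torch.sin(angles), torch.cos(angles)

q = torch.randn(1, 2, H * W * D, head_dim)  # (B, num_heads, N, head_dim)
q_rot = rope_apply(q, sin, cos)
```

Because the angles are tiled so that channel `j` and channel `j + head_dim // 2` share an angle, each such pair is rotated as a 2D vector, which leaves the per-token norm unchanged.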
spectre/models/resnet.py ADDED
@@ -0,0 +1,726 @@
1
+ import os
2
+ import math
3
+ from urllib.parse import urlparse
4
+ from typing import Type, Any, Tuple, List, Optional, Union, Dict
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from spectre.utils import to_ntuple
11
+
12
+
13
+ def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
14
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
15
+ return padding
16
+
17
+
18
+ class BasicBlock(nn.Module):
19
+ expansion = 1
20
+
21
+ def __init__(
22
+ self,
23
+ inplanes: int,
24
+ planes: int,
25
+ stride: int = 1,
26
+ downsample: Optional[nn.Module] = None,
27
+ cardinality: int = 1,
28
+ base_width: int = 64,
29
+ reduce_first: int = 1,
30
+ dilation: int = 1,
31
+ first_dilation: Optional[int] = None,
32
+ act_layer: Type[nn.Module] = nn.ReLU,
33
+ norm_layer: Type[nn.Module] = nn.BatchNorm3d,
34
+ ):
35
+ """
36
+ Args:
37
+ inplanes: Input channel dimensionality.
38
+ planes: Used to determine output channel dimensionalities.
39
+ stride: Stride used in convolution layers.
40
+ downsample: Optional downsample layer for residual path.
41
+ cardinality: Number of convolution groups.
42
+ base_width: Base width used to determine output channel dimensionality.
43
+ reduce_first: Reduction factor for first convolution output width of residual blocks.
44
+ dilation: Dilation rate for convolution layers.
45
+ first_dilation: Dilation rate for first convolution layer.
46
+ act_layer: Activation layer.
47
+ norm_layer: Normalization layer.
48
+ """
49
+ super(BasicBlock, self).__init__()
50
+
51
+ assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
52
+ assert base_width == 64, 'BasicBlock does not support changing base width'
53
+ first_planes = planes // reduce_first
54
+ outplanes = planes * self.expansion
55
+ first_dilation = first_dilation or dilation
56
+
57
+ self.conv1 = nn.Conv3d(
58
+ inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
59
+ dilation=first_dilation, bias=False)
60
+ self.bn1 = norm_layer(first_planes)
61
+ self.act1 = act_layer()
62
+
63
+ self.conv2 = nn.Conv3d(
64
+ first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
65
+ self.bn2 = norm_layer(outplanes)
66
+
67
+ self.act2 = act_layer()
68
+ self.downsample = downsample
69
+ self.stride = stride
70
+ self.dilation = dilation
71
+
72
+ def zero_init_last(self):
73
+ if getattr(self.bn2, 'weight', None) is not None:
74
+ nn.init.zeros_(self.bn2.weight)
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ shortcut = x
78
+
79
+ x = self.conv1(x)
80
+ x = self.bn1(x)
81
+ x = self.act1(x)
82
+
83
+ x = self.conv2(x)
84
+ x = self.bn2(x)
85
+
86
+ if self.downsample is not None:
87
+ shortcut = self.downsample(shortcut)
88
+ x = x + shortcut
89
+ x = self.act2(x)
90
+
91
+ return x
92
+
93
+
94
+ class Bottleneck(nn.Module):
95
+ expansion = 4
96
+
97
+ def __init__(
98
+ self,
99
+ inplanes: int,
100
+ planes: int,
101
+ stride: int = 1,
102
+ downsample: Optional[nn.Module] = None,
103
+ cardinality: int = 1,
104
+ base_width: int = 64,
105
+ reduce_first: int = 1,
106
+ dilation: int = 1,
107
+ first_dilation: Optional[int] = None,
108
+ act_layer: Type[nn.Module] = nn.ReLU,
109
+ norm_layer: Type[nn.Module] = nn.BatchNorm3d,
110
+ ):
111
+ """
112
+ Args:
113
+ inplanes: Input channel dimensionality.
114
+ planes: Used to determine output channel dimensionalities.
115
+ stride: Stride used in convolution layers.
116
+ downsample: Optional downsample layer for residual path.
117
+ cardinality: Number of convolution groups.
118
+ base_width: Base width used to determine output channel dimensionality.
119
+ reduce_first: Reduction factor for first convolution output width of residual blocks.
120
+ dilation: Dilation rate for convolution layers.
121
+ first_dilation: Dilation rate for first convolution layer.
122
+ act_layer: Activation layer.
123
+ norm_layer: Normalization layer.
124
+ """
125
+ super(Bottleneck, self).__init__()
126
+
127
+ width = int(math.floor(planes * (base_width / 64)) * cardinality)
128
+ first_planes = width // reduce_first
129
+ outplanes = planes * self.expansion
130
+ first_dilation = first_dilation or dilation
131
+
132
+ self.conv1 = nn.Conv3d(inplanes, first_planes, kernel_size=1, bias=False)
133
+ self.bn1 = norm_layer(first_planes)
134
+ self.act1 = act_layer()
135
+
136
+ self.conv2 = nn.Conv3d(
137
+ first_planes, width, kernel_size=3, stride=stride,
138
+ padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
139
+ self.bn2 = norm_layer(width)
140
+ self.act2 = act_layer()
141
+
142
+ self.conv3 = nn.Conv3d(width, outplanes, kernel_size=1, bias=False)
143
+ self.bn3 = norm_layer(outplanes)
144
+
145
+ self.act3 = act_layer()
146
+ self.downsample = downsample
147
+ self.stride = stride
148
+ self.dilation = dilation
149
+
150
+ def zero_init_last(self):
151
+ if getattr(self.bn3, 'weight', None) is not None:
152
+ nn.init.zeros_(self.bn3.weight)
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ shortcut = x
156
+
157
+ x = self.conv1(x)
158
+ x = self.bn1(x)
159
+ x = self.act1(x)
160
+
161
+ x = self.conv2(x)
162
+ x = self.bn2(x)
163
+ x = self.act2(x)
164
+
165
+ x = self.conv3(x)
166
+ x = self.bn3(x)
167
+
168
+ if self.downsample is not None:
169
+ shortcut = self.downsample(shortcut)
170
+ x = x + shortcut
171
+ x = self.act3(x)
172
+
173
+ return x
174
+
175
+
176
+ def downsample_conv(
177
+ in_channels: int,
178
+ out_channels: int,
179
+ kernel_size: int,
180
+ stride: int = 1,
181
+ dilation: int = 1,
182
+ first_dilation: Optional[int] = None,
183
+ norm_layer: Optional[Type[nn.Module]] = None,
184
+ ) -> nn.Module:
185
+ norm_layer = norm_layer or nn.BatchNorm3d
186
+ kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
187
+ first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
188
+ p = get_padding(kernel_size, stride, first_dilation)
189
+
190
+ return nn.Sequential(*[
191
+ nn.Conv3d(
192
+ in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False),
193
+ norm_layer(out_channels)
194
+ ])
195
+
196
+
197
+ def downsample_avg(
198
+ in_channels: int,
199
+ out_channels: int,
200
+ kernel_size: int,
201
+ stride: int = 1,
202
+ dilation: int = 1,
203
+ first_dilation: Optional[int] = None,
204
+ norm_layer: Optional[Type[nn.Module]] = None,
205
+ ) -> nn.Module:
206
+ norm_layer = norm_layer or nn.BatchNorm3d
207
+ avg_stride = stride if dilation == 1 else 1
208
+ if stride == 1 and dilation == 1:
209
+ pool = nn.Identity()
210
+ else:
211
+ pool = nn.AvgPool3d(2, avg_stride, ceil_mode=True, count_include_pad=False)
212
+
213
+ return nn.Sequential(*[
214
+ pool,
215
+ nn.Conv3d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
216
+ norm_layer(out_channels)
217
+ ])
218
+
219
+
220
+ def make_blocks(
221
+ block_fns: Tuple[Union[BasicBlock, Bottleneck]],
222
+ channels: Tuple[int, ...],
223
+ block_repeats: Tuple[int, ...],
224
+ inplanes: int,
225
+ reduce_first: int = 1,
226
+ output_stride: int = 32,
227
+ down_kernel_size: int = 1,
228
+ avg_down: bool = False,
229
+ **kwargs,
230
+ ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
231
+ stages = []
232
+ feature_info = []
233
+ net_num_blocks = sum(block_repeats)
234
+ net_block_idx = 0
235
+ net_stride = 4
236
+ dilation = prev_dilation = 1
237
+ for stage_idx, (block_fn, planes, num_blocks) in enumerate(zip(block_fns, channels, block_repeats)):
238
+ stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
239
+ stride = 1 if stage_idx == 0 else 2
240
+ if net_stride >= output_stride:
241
+ dilation *= stride
242
+ stride = 1
243
+ else:
244
+ net_stride *= stride
245
+
246
+ downsample = None
247
+ if stride != 1 or inplanes != planes * block_fn.expansion:
248
+ down_kwargs = dict(
249
+ in_channels=inplanes,
250
+ out_channels=planes * block_fn.expansion,
251
+ kernel_size=down_kernel_size,
252
+ stride=stride,
253
+ dilation=dilation,
254
+ first_dilation=prev_dilation,
255
+ norm_layer=kwargs.get('norm_layer'),
256
+ )
257
+ downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
258
+
259
+ block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, **kwargs)
260
+ blocks = []
261
+ for block_idx in range(num_blocks):
262
+ downsample = downsample if block_idx == 0 else None
263
+ stride = stride if block_idx == 0 else 1
264
+ blocks.append(block_fn(
265
+ inplanes,
266
+ planes,
267
+ stride,
268
+ downsample,
269
+ first_dilation=prev_dilation,
270
+ **block_kwargs,
271
+ ))
272
+ prev_dilation = dilation
273
+ inplanes = planes * block_fn.expansion
274
+ net_block_idx += 1
275
+
276
+ stages.append((stage_name, nn.Sequential(*blocks)))
277
+ feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
278
+
279
+ return stages, feature_info
280
+
281
+
282
+ def feature_take_indices(
283
+ num_features: int,
284
+ indices: Optional[Union[int, List[int]]] = None,
285
+ as_set: bool = False,
286
+ ) -> Tuple[List[int], int]:
287
+ """ Determine the absolute feature indices to 'take' from.
288
+
289
+ Note: This function can be called in forward() so must be torchscript compatible,
290
+ which requires some incomplete typing and workaround hacks.
291
+
292
+ Args:
293
+ num_features: total number of features to select from
294
+ indices: indices to select,
295
+ None -> select all
296
+ int -> select last n
297
+ list/tuple of int -> return specified (-ve indices specify from end)
298
+ as_set: return as a set
299
+
300
+ Returns:
301
+ List (or set) of absolute (from beginning) indices, Maximum index
302
+ """
303
+ if indices is None:
304
+ indices = num_features # all features if None
305
+
306
+ if isinstance(indices, int):
307
+ # convert int -> last n indices
308
+ assert 0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})'
309
+ take_indices = [num_features - indices + i for i in range(indices)]
310
+ else:
311
+ take_indices: List[int] = []
312
+ for i in indices:
313
+ idx = num_features + i if i < 0 else i
314
+ assert 0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})'
315
+ take_indices.append(idx)
316
+
317
+ if not torch.jit.is_scripting() and as_set:
318
+ return set(take_indices), max(take_indices)
319
+
320
+ return take_indices, max(take_indices)
321
+
322
+
323
+ class ResNet(nn.Module):
324
+ """ResNet / ResNeXt
325
+
326
+ This class implements all variants of ResNet, ResNeXt that
327
+ * have > 1 stride in the 3x3 conv layer of bottleneck
328
+ * have conv-bn-act ordering
329
+
330
+ This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
331
+ variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
332
+ 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
333
+
334
+ ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
335
+ * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
336
+ * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
337
+ * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
338
+ * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
339
+ * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
340
+ * t - 3 layer deep 3x3 stem, stem_width = 32 (24, 48, 64), average pool in downsample
340
+ * tn - 3 layer deep 3x3 stem, stem_width = 32 (24, 32, 64), average pool in downsample
342
+
343
+ ResNeXt
344
+ * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
345
+ * same c, d, e, s variants as ResNet can be enabled
346
+ """
347
+
348
+ def __init__(
349
+ self,
350
+ block: Type[Union[BasicBlock, Bottleneck]],
351
+ layers: Tuple[int, ...],
352
+ num_classes: int = 1000,
353
+ in_chans: int = 1,
354
+ output_stride: int = 32,
355
+ global_pool: str = 'avg',
356
+ cardinality: int = 1,
357
+ base_width: int = 64,
358
+ stem_width: int = 64,
359
+ stem_type: str = '',
360
+ replace_stem_pool: bool = False,
361
+ block_reduce_first: int = 1,
362
+ down_kernel_size: int = 1,
363
+ avg_down: bool = False,
364
+ channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
365
+ act_layer: Type[nn.Module] = nn.ReLU,
366
+ norm_layer: Type[nn.Module] = nn.BatchNorm3d,
367
+ drop_rate: float = 0.0,
368
+ zero_init_last: bool = True,
369
+ block_args: Optional[Dict[str, Any]] = None,
370
+ ):
371
+ """
372
+ Args:
373
+ block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
374
+ layers (Tuple[int, ...]): number of residual blocks in each stage
375
+ num_classes (int): number of classification classes (default 1000)
376
+ in_chans (int): number of input channels. (default 1)
377
+ output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
378
+ global_pool (str): Global pooling type. One of 'avg', 'max' (default 'avg')
379
+ cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
380
+ base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
381
+ stem_width (int): number of channels in stem convolutions (default 64)
382
+ stem_type (str): The type of stem (default ''):
383
+ * '', default - a single 7x7 conv with a width of stem_width
384
+ * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
385
+ * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
386
+ replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
387
+ block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
388
+ 1 for all archs except senets, where 2 (default 1)
389
+ down_kernel_size (int): kernel size of residual block downsample path,
390
+ 1x1 for most, 3x3 for senets (default: 1)
391
+ avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
392
+ act_layer (str, nn.Module): activation layer
393
+ norm_layer (str, nn.Module): normalization layer
394
+ drop_rate (float): Dropout probability before classifier, for training (default 0.)
395
+ zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
396
+ block_args (dict): Extra kwargs to pass through to block module
397
+ """
398
+ super(ResNet, self).__init__()
399
+ block_args = block_args or dict()
400
+ assert output_stride in (8, 16, 32)
401
+ self.num_classes = num_classes
402
+ self.drop_rate = drop_rate
403
+
404
+ # Stem
405
+ deep_stem = 'deep' in stem_type
406
+ inplanes = stem_width * 2 if deep_stem else 64
407
+ if deep_stem:
408
+ stem_chs = (stem_width, stem_width)
409
+ if 'tiered' in stem_type:
410
+ stem_chs = (3 * (stem_width // 4), stem_width)
411
+ self.conv1 = nn.Sequential(*[
412
+ nn.Conv3d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
413
+ norm_layer(stem_chs[0]),
414
+ act_layer(),
415
+ nn.Conv3d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
416
+ norm_layer(stem_chs[1]),
417
+ act_layer(),
418
+ nn.Conv3d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
419
+ else:
420
+ self.conv1 = nn.Conv3d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
421
+ self.bn1 = norm_layer(inplanes)
422
+ self.act1 = act_layer()
423
+ self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
424
+
425
+ # Stem pooling. The name 'maxpool' remains for weight compatibility.
426
+ if replace_stem_pool:
427
+ self.maxpool = nn.Sequential(*filter(None, [
428
+ nn.Conv3d(inplanes, inplanes, 3, stride=2, padding=1, bias=False),
429
+ norm_layer(inplanes),
430
+ act_layer(),
431
+ ]))
432
+ else:
433
+ self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
434
+
435
+ # Feature Blocks
436
+ block_fns = to_ntuple(len(channels))(block)
437
+ stage_modules, stage_feature_info = make_blocks(
438
+ block_fns,
439
+ channels,
440
+ layers,
441
+ inplanes,
442
+ cardinality=cardinality,
443
+ base_width=base_width,
444
+ output_stride=output_stride,
445
+ reduce_first=block_reduce_first,
446
+ avg_down=avg_down,
447
+ down_kernel_size=down_kernel_size,
448
+ act_layer=act_layer,
449
+ norm_layer=norm_layer,
450
+ **block_args,
451
+ )
452
+ for stage in stage_modules:
453
+ self.add_module(*stage) # layer1, layer2, etc
454
+ self.feature_info.extend(stage_feature_info)
455
+
456
+ # Head (Pooling and Classifier)
457
+ self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
458
+ if global_pool == 'avg':
459
+ self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
460
+ elif global_pool == 'max':
461
+ self.global_pool = nn.AdaptiveMaxPool3d((1, 1, 1))
462
+ else:
463
+ raise NotImplementedError('Global pooling type not supported: {}'.format(global_pool))
464
+
465
+ self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
466
+
467
+ self.init_weights(zero_init_last=zero_init_last)
468
+
469
+ @torch.jit.ignore
470
+ def init_weights(self, zero_init_last: bool = True):
471
+ for n, m in self.named_modules():
472
+ if isinstance(m, nn.Conv3d):
473
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
474
+ if zero_init_last:
475
+ for m in self.modules():
476
+ if hasattr(m, 'zero_init_last'):
477
+ m.zero_init_last()
478
+
479
+ @torch.jit.ignore
480
+ def group_matcher(self, coarse: bool = False):
481
+ matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
482
+ return matcher
483
+
484
+ @torch.jit.ignore
485
+ def get_classifier(self, name_only: bool = False):
486
+ return 'fc' if name_only else self.fc
487
+
488
+ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
489
+ self.num_classes = num_classes
490
+ if global_pool == 'avg':
491
+ self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
492
+ elif global_pool == 'max':
493
+ self.global_pool = nn.AdaptiveMaxPool3d((1, 1, 1))
494
+ else:
495
+ raise NotImplementedError('Global pooling type not supported: {}'.format(global_pool))
496
+
497
+ self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
498
+
499
+ def forward_intermediates(
500
+ self,
501
+ x: torch.Tensor,
502
+ indices: Optional[Union[int, List[int]]] = None,
503
+ norm: bool = False,
504
+ stop_early: bool = False,
505
+ output_fmt: str = 'NCHWD',
506
+ intermediates_only: bool = False,
507
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
508
+ """ Forward features that returns intermediates.
509
+
510
+ Args:
511
+ x: Input image tensor
512
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
513
+ norm: Apply norm layer to compatible intermediates
514
+ stop_early: Stop iterating over blocks when last desired intermediate hit
515
+ output_fmt: Shape of intermediate feature outputs
516
+ intermediates_only: Only return intermediate features
517
+ Returns:
518
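+ The last computed feature map and the list of selected intermediates, or only that list if intermediates_only is True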
519
+ """
520
+ assert output_fmt in ('NCHWD',), 'Output shape must be NCHWD.'
521
+ intermediates = []
522
+ take_indices, max_index = feature_take_indices(5, indices)
523
+
524
+ # forward pass
525
+ feat_idx = 0
526
+ x = self.conv1(x)
527
+ x = self.bn1(x)
528
+ x = self.act1(x)
529
+ if feat_idx in take_indices:
530
+ intermediates.append(x)
531
+ x = self.maxpool(x)
532
+
533
+ layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
534
+ if stop_early:
535
+ layer_names = layer_names[:max_index]
536
+ for n in layer_names:
537
+ feat_idx += 1
538
+ x = getattr(self, n)(x) # getattr lookup is not torchscript compatible, but keeps the loop simple
539
+ if feat_idx in take_indices:
540
+ intermediates.append(x)
541
+
542
+ if intermediates_only:
543
+ return intermediates
544
+
545
+ return x, intermediates
546
+
547
+ def prune_intermediate_layers(
548
+ self,
549
+ indices: Union[int, List[int]] = 1,
550
+ prune_norm: bool = False,
551
+ prune_head: bool = True,
552
+ ):
553
+ """ Prune layers not required for specified intermediates.
554
+ """
555
+ take_indices, max_index = feature_take_indices(5, indices)
556
+ layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
557
+ layer_names = layer_names[max_index:]
558
+ for n in layer_names:
559
+ setattr(self, n, nn.Identity())
560
+ if prune_head:
561
+ self.reset_classifier(0) # '' is not a supported pooling type in this class, so keep the default 'avg'
562
+ return take_indices
563
+
564
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
565
+ x = self.conv1(x)
566
+ x = self.bn1(x)
567
+ x = self.act1(x)
568
+ x = self.maxpool(x)
569
+
570
+ x = self.layer1(x)
571
+ x = self.layer2(x)
572
+ x = self.layer3(x)
573
+ x = self.layer4(x)
574
+ return x
575
+
576
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
577
+ x = self.global_pool(x)
578
+ x = x.flatten(1)
579
+ if self.drop_rate:
580
+ x = F.dropout(x, p=float(self.drop_rate), training=self.training)
581
+ return x if pre_logits else self.fc(x)
582
+
583
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
584
+ x = self.forward_features(x)
585
+ x = self.forward_head(x)
586
+ return x
587
+
588
+ @classmethod
589
+ def from_pretrained(
590
+ cls,
591
+ checkpoint_path_or_url: Union[str, os.PathLike],
592
+ verbose: bool = True,
593
+ **kwargs
594
+ ) -> 'ResNet':
595
+ """Load pretrained model weights from a local path or a URL."""
596
+ model = cls(**kwargs)
597
+
598
+ def _is_url(path: str) -> bool:
599
+ try:
600
+ parsed = urlparse(str(path))
601
+ return parsed.scheme in ('http', 'https')
602
+ except Exception:
603
+ return False
604
+
605
+ if _is_url(checkpoint_path_or_url):
606
+ if verbose:
607
+ print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
608
+ state_dict = torch.hub.load_state_dict_from_url(
609
+ checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
610
+ else:
611
+ local_path = os.fspath(checkpoint_path_or_url)
612
+ if not os.path.exists(local_path):
613
+ raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
614
+ if verbose:
615
+ print(f"Loading checkpoint from local path: {local_path}")
616
+ state_dict = torch.load(local_path, map_location='cpu', weights_only=False)
617
+
618
+ msg = model.load_state_dict(state_dict, strict=False)
619
+ if verbose:
620
+ print(f"Loaded pretrained weights with msg: {msg}")
621
+ return model
622
+
623
+
624
+
625
+ def resnet18(
626
+ checkpoint_path_or_url: Optional[str] = None,
627
+ **kwargs
628
+ ) -> ResNet:
629
+ """ResNet-18 model with 3D operations.
630
+ """
631
+ kwargs = dict(
632
+ block=BasicBlock,
633
+ layers=[2, 2, 2, 2],
634
+ cardinality=1,
635
+ **kwargs,
636
+ )
637
+ if checkpoint_path_or_url:
638
+ return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
639
+ return ResNet(**kwargs)
640
+
641
+
642
+ def resnet34(
643
+ checkpoint_path_or_url: Optional[str] = None,
644
+ **kwargs
645
+ ) -> ResNet:
646
+ """ResNet-34 model with 3D operations.
647
+ """
648
+ kwargs = dict(
649
+ block=BasicBlock,
650
+ layers=[3, 4, 6, 3],
651
+ cardinality=1,
652
+ **kwargs,
653
+ )
654
+ if checkpoint_path_or_url:
655
+ return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
656
+ return ResNet(**kwargs)
657
+
658
+
659
+ def resnet50(
660
+ checkpoint_path_or_url: Optional[str] = None,
661
+ **kwargs
662
+ ) -> ResNet:
663
+ """ResNet-50 model with 3D operations.
664
+ """
665
+ kwargs = dict(
666
+ block=Bottleneck,
667
+ layers=[3, 4, 6, 3],
668
+ cardinality=1,
669
+ **kwargs,
670
+ )
671
+ if checkpoint_path_or_url:
672
+ return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
673
+ return ResNet(**kwargs)
674
+
675
+
676
+ def resnet101(
677
+ checkpoint_path_or_url: Optional[str] = None,
678
+ **kwargs
679
+ ) -> ResNet:
680
+ """ResNet-101 model with 3D operations.
681
+ """
682
+ kwargs = dict(
683
+ block=Bottleneck,
684
+ layers=[3, 4, 23, 3],
685
+ cardinality=1,
686
+ **kwargs,
687
+ )
688
+ if checkpoint_path_or_url:
689
+ return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
690
+ return ResNet(**kwargs)
691
+
692
+
693
+ def resnext50(
694
+ checkpoint_path_or_url: Optional[str] = None,
695
+ **kwargs
696
+ ) -> ResNet:
697
+ """ResNeXt-50 model with 3D operations.
698
+ """
699
+ kwargs = dict(
700
+ block=Bottleneck,
701
+ layers=[3, 4, 6, 3],
702
+ cardinality=32,
703
+ base_width=4,
704
+ **kwargs,
705
+ )
706
+ if checkpoint_path_or_url:
707
+ return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
708
+ return ResNet(**kwargs)
709
+
710
+
711
+ def resnext101(
712
+ checkpoint_path_or_url: Optional[str] = None,
713
+ **kwargs
714
+ ) -> ResNet:
715
+ """ResNeXt-101 model with 3D operations.
716
+ """
717
+ kwargs = dict(
718
+ block=Bottleneck,
719
+ layers=[3, 4, 23, 3],
720
+ cardinality=32,
721
+ base_width=8,
722
+ **kwargs,
723
+ )
724
+ if checkpoint_path_or_url:
725
+ return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
726
+ return ResNet(**kwargs)
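
Review note on `feature_take_indices`: the index resolution that `forward_intermediates` and `prune_intermediate_layers` rely on can be restated in plain Python. The sketch below is an illustration only, not the module's code: the name `take_indices` is hypothetical, the TorchScript workarounds and the `as_set` option are left out, and `ValueError` replaces the original assertions.

```python
from typing import List, Optional, Tuple, Union


def take_indices(
    num_features: int,
    indices: Optional[Union[int, List[int]]] = None,
) -> Tuple[List[int], int]:
    """Resolve a feature selection to absolute indices plus the maximum index."""
    if indices is None:
        indices = num_features  # None selects every feature level

    if isinstance(indices, int):
        # An int n selects the last n feature levels
        if not 0 < indices <= num_features:
            raise ValueError(f"last-n ({indices}) is out of range (1 to {num_features})")
        resolved = list(range(num_features - indices, num_features))
    else:
        resolved = []
        for i in indices:
            idx = num_features + i if i < 0 else i  # negative indices count from the end
            if not 0 <= idx < num_features:
                raise ValueError(f"feature index {idx} is out of range (0 to {num_features - 1})")
            resolved.append(idx)

    return resolved, max(resolved)


# The 3D ResNet passes num_features=5: the stem activation plus layer1..layer4.
print(take_indices(5))           # ([0, 1, 2, 3, 4], 4)
print(take_indices(5, 2))        # ([3, 4], 4)
print(take_indices(5, [0, -1]))  # ([0, 4], 4)
```

The returned maximum index is what lets `forward_intermediates` cut `layer_names` short when `stop_early` is set.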
spectre/models/seomt.py ADDED
@@ -0,0 +1,394 @@
1
+ # Adapted from https://github.com/tue-mps/eomt/
2
+
3
+ import math
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from spectre.models import VisionTransformer
11
+ from spectre.models.layers import LayerNorm3d
12
+
13
+
14
+ class ScaleBlock(nn.Module):
15
+ def __init__(
16
+ self,
17
+ embed_dim: int,
18
+ scale_factors: Union[int, Tuple[int, int, int]] = (2, 2, 2),
19
+ conv1_layer: nn.Module = nn.ConvTranspose3d,
20
+ ):
21
+ super().__init__()
22
+
23
+ self.conv1 = conv1_layer(
24
+ embed_dim,
25
+ embed_dim,
26
+ kernel_size=scale_factors,
27
+ stride=scale_factors,
28
+ )
29
+ self.act = nn.GELU()
30
+ self.conv2 = nn.Conv3d(
31
+ embed_dim,
32
+ embed_dim,
33
+ kernel_size=3,
34
+ padding=1,
35
+ groups=embed_dim,
36
+ bias=False,
37
+ )
38
+ self.norm = LayerNorm3d(embed_dim)
39
+
40
+ def forward(self, x):
41
+ # print(x.shape)
42
+ x = self.conv1(x)
43
+ x = self.act(x)
44
+ x = self.conv2(x)
45
+ x = self.norm(x)
46
+
47
+ return x
48
+
49
+
50
+ def compute_upscale_stages(patch_size, min_size=4):
51
+ # Compute how many times to upscale per dimension
52
+ num_stages = []
53
+ for size in patch_size:
54
+ stages = max(0, int(math.log2(size)) - int(math.log2(min_size)))
55
+ num_stages.append(stages)
56
+ return num_stages
57
+
58
+ def voxel_shuffle_3d(x: torch.Tensor, r: int = 2) -> torch.Tensor:
59
+ """
60
+ Rearranges channels of a 5D tensor (N, C*r^3, D, H, W) to
61
+ (N, C, D*r, H*r, W*r).
62
+ """
63
+ n, c, d, h, w = x.size()
64
+ assert c % (r ** 3) == 0, f"Channels {c} not divisible by r^3={r**3}"
65
+ out_c = c // (r ** 3)
66
+ x = x.view(n, out_c, r, r, r, d, h, w) # (N, C, r, r, r, D, H, W)
67
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() # (N, C, D, r, H, r, W, r)
68
+ x = x.view(n, out_c, d * r, h * r, w * r) # (N, C, D*r, H*r, W*r)
69
+ return x
70
+
71
+
72
+ class MLPUpBlock3D(nn.Module):
73
+ """
74
+ 2x upsampling via:
75
+ - 1x1x1 Conv (per-voxel MLP) expanding channels by 2^3
76
+ - 3D voxel shuffle to double D/H/W
77
+ - optional norm + activation
78
+ """
79
+ def __init__(self, channels: int, norm=nn.Identity, activation=nn.GELU):
80
+ super().__init__()
81
+ self.proj = nn.Conv3d(channels, channels * 8, kernel_size=1, bias=True)
82
+ self.norm = norm(channels) if norm is not None else nn.Identity()
83
+ self.act = activation() if activation is not None else nn.Identity()
84
+
85
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
86
+ x = self.proj(x) # (N, C*8, D, H, W)
87
+ x = voxel_shuffle_3d(x, 2) # (N, C, 2D, 2H, 2W)
88
+ x = self.norm(x)
89
+ x = self.act(x)
90
+ return x
91
+
92
+
93
+ class SmolMLPDecoder3D(nn.Module):
94
+ """
95
+ Simple decoder with two MLP upsampling layers (2x each) for total 4x upsampling.
96
+ """
97
+ def __init__(
98
+ self,
99
+ in_channels: int,
100
+ out_channels: int,
101
+ norm=LayerNorm3d,
102
+ activation=nn.GELU,
103
+ ):
104
+ super().__init__()
105
+ self.up1 = MLPUpBlock3D(in_channels, norm=norm, activation=activation)
106
+ self.up2 = MLPUpBlock3D(in_channels, norm=norm, activation=activation)
107
+ self.head = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True)
108
+
109
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
110
+ x = self.up1(x) # 2x
111
+ x = self.up2(x) # another 2x => 4x total
112
+ x = self.head(x) # map to desired output channels
113
+ return x
114
+
115
+ class SpatialMLPDecoder3D(nn.Module):
116
+ """
117
+ Per-voxel MLP that expands logits from (B,Q,H,W,D) to (B,Q,H*s,W*s,D*s)
118
+ by predicting a learned s^3 block per voxel, then rearranging spatially.
119
+ """
120
+ def __init__(self, num_classes: int, upscale_factor: int = 4, hidden_mul: int = 4):
121
+ super().__init__()
122
+ self.num_classes = num_classes
123
+ self.s = upscale_factor
124
+ hidden = hidden_mul * num_classes
125
+ self.mlp = nn.Sequential(
126
+ nn.Linear(num_classes, hidden),
127
+ nn.GELU(),
128
+ nn.Linear(hidden, num_classes * (upscale_factor ** 3)),
129
+ )
130
+
131
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
132
+ # x: (B, Q, H, W, D)
133
+ B, Q, H, W, D = x.shape
134
+ s = self.s
135
+ x = x.permute(0, 2, 3, 4, 1).contiguous().view(B * H * W * D, Q) # (B*H*W*D, Q)
136
+ x = self.mlp(x) # (B*H*W*D, Q*s^3)
137
+ x = x.view(B, H, W, D, Q, s, s, s) # (B,H,W,D,Q,s,s,s)
138
+ x = x.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous() # (B,Q,H,s,W,s,D,s)
139
+ x = x.view(B, Q, H * s, W * s, D * s)
140
+ return x
141
+
142
+
143
+ class SimpleConvDecoder3D(nn.Module):
144
+ """
145
+ Depthwise ConvTranspose3d upsampler:
146
+ - Assumes input & output channels == num_classes
147
+ - Uses groups=num_classes for class-wise deconvolution
148
+ """
149
+ def __init__(self, num_classes: int, upscale_factor: int = 4):
150
+ super().__init__()
151
+ s = upscale_factor
152
+ self.deconv = nn.ConvTranspose3d(
153
+ in_channels=num_classes,
154
+ out_channels=num_classes,
155
+ kernel_size=s,
156
+ stride=s,
157
+ padding=0,
158
+ output_padding=0,
159
+ groups=num_classes,
160
+ bias=True,
161
+ )
162
+
163
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
164
+ # x: (B, Q, H, W, D)
165
+ return self.deconv(x)
166
+
167
+
168
+ class SEoMT(nn.Module):
169
+ def __init__(
170
+ self,
171
+ backbone: VisionTransformer,
172
+ num_classes: int,
173
+ # num_q: int,
174
+ num_blocks=4,
175
+ masked_attn_enabled=True,
176
+ return_only_final_layer=False,
177
+ upscale_output=True,
178
+ for_nnunet=False,
179
+ decoder=False,
180
+ ):
181
+ super().__init__()
182
+ self.backbone = backbone
183
+ self.num_q = num_classes
184
+ self.num_blocks = num_blocks
185
+ self.masked_attn_enabled = masked_attn_enabled
186
+ self.return_only_final_layer = return_only_final_layer
187
+ self.upscale_output = upscale_output
188
+ self.register_buffer("attn_mask_probs", torch.ones(num_blocks))
189
+ self.for_nnunet = for_nnunet
190
+ self.q = nn.Embedding(num_classes, self.backbone.embed_dim)
191
+
192
+
193
+ self.mask_head = nn.Sequential(
194
+ nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
195
+ nn.GELU(),
196
+ nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
197
+ nn.GELU(),
198
+ nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
199
+ )
200
+
201
+ patch_size = self.backbone.patch_embed.patch_size
202
+ num_upscale_stages = compute_upscale_stages(patch_size, min_size=4)
203
+
204
+ # Build per-stage scale factors list
205
+ max_stages = max(num_upscale_stages)
206
+ upscale_blocks = []
207
+ for stage_idx in range(max_stages):
208
+ # for each dimension, upscale by 2 only if this dimension still has
209
+ # remaining upscales at this stage
210
+ scale_factors = tuple(
211
+ 2 if stage_idx < num_upscale_stages[dim] else 1
212
+ for dim in range(len(patch_size))
213
+ )
214
+ upscale_blocks.append(ScaleBlock(self.backbone.embed_dim,
215
+ scale_factors=scale_factors))
216
+
217
+ self.upscale = nn.Sequential(*upscale_blocks)
218
+
219
+ def _predict(self, x: torch.Tensor, stage: Optional[int] = None):
220
+ q = x[:, : self.num_q, :]
221
+ # print(stage)
222
+ # class_logits = self.class_head(q)
223
+ x = x[:, self.num_q + self.backbone.num_prefix_tokens :, :]
224
+ x = x.transpose(1, 2).reshape(
225
+ x.shape[0], -1, *self.backbone.patch_embed.grid_size
226
+ )
227
+ mask_logits = torch.einsum(
228
+ "bqc, bchwd -> bqhwd", self.mask_head(q), self.upscale(x)
229
+ )
230
+
231
+
232
+ return mask_logits
233
+
234
+ @torch.compiler.disable
235
+ def _disable_attn_mask(self, attn_mask, prob):
236
+ if prob < 1:
237
+ random_queries = (
238
+ torch.rand(attn_mask.shape[0], self.num_q, device=attn_mask.device)
239
+ > prob
240
+ )
241
+ attn_mask[
242
+ :, : self.num_q, self.num_q + self.backbone.num_prefix_tokens :
243
+ ][random_queries] = True
244
+
245
+ return attn_mask
246
+
247
+ def _attn(self, module: 'Attention', x: torch.Tensor, mask: Optional[torch.Tensor], rope = None):
248
+ B, N, C = x.shape
249
+
250
+ q = module.q(x).reshape(B, N, module.num_heads, module.head_dim).permute(0, 2, 1, 3)
251
+ kv = module.kv(x).reshape(B, N, 2, module.num_heads, module.head_dim)
252
+ k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
253
+ q, k = module.q_norm(q), module.k_norm(k)
254
+
255
+ if mask is not None:
256
+ mask = mask[:, None, ...].expand(-1, module.num_heads, -1, -1)
257
+
258
+ dropout_p = module.attn_drop.p if self.training else 0.0
259
+
260
+ if rope is not None:
261
+ if isinstance(rope, list):
262
+ rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
263
+ q, k = module.apply_rotary_pos_emb(q, k, rope)
264
+
265
+ if module.fused_attn:
266
+ x = F.scaled_dot_product_attention(q, k, v, mask, dropout_p)
267
+ else:
268
+ attn = (q @ k.transpose(-2, -1)) * module.scale
269
+ if mask is not None:
270
+ attn = attn.masked_fill(~mask, float("-inf"))
271
+ attn = F.softmax(attn, dim=-1)
272
+ attn = module.attn_drop(attn)
273
+ x = attn @ v
274
+
275
+ x = module.proj_drop(module.proj(x.transpose(1, 2).reshape(B, N, C)))
276
+
277
+ return x
278
+
279
+ def forward(self, x: torch.Tensor):
280
+
281
+ if self.for_nnunet: # nnU-Net supplies volumes depth-first (czyx); move the depth axis to the last spatial position
282
+ x = x.permute(0, 1, 3, 4, 2).contiguous()
283
+
284
+ self.backbone.patch_embed.set_input_size(x.shape[2:])
285
+ x = self.backbone.patch_embed(x)
286
+ x, rope = self.backbone._pos_embed(x)
287
+ x = self.backbone.patch_drop(x)
288
+ x = self.backbone.norm_pre(x)
289
+ attn_mask = None
290
+ mask_logits_per_layer = []
291
+
292
+ for i, block in enumerate(self.backbone.blocks):
293
+ if i == len(self.backbone.blocks) - self.num_blocks:
294
+ x = torch.cat(
295
+ (self.q.weight[None, :, :].expand(x.shape[0], -1, -1), x), dim=1
296
+ )
297
+
298
+ if (
299
+ self.masked_attn_enabled
300
+ and i >= len(self.backbone.blocks) - self.num_blocks
301
+ ):
302
+ mask_logits = self._predict(self.backbone.norm(x))
303
+
304
+ if self.for_nnunet:
305
+ # swap back to czyx
306
+
307
+ if self.upscale_output:
308
+ # Upscale to original input size / stage
309
+
310
+ stage = len(self.backbone.blocks) - i
311
+ if stage is not None:
312
+ input_size = tuple(
313
+ int(self.backbone.patch_embed.patch_size[dim] * self.backbone.patch_embed.grid_size[dim] / 2**(stage))
314
+ for dim in range(len(self.backbone.patch_embed.patch_size))
315
+ )
316
+
317
+ mask_logits_per_layer.append(F.interpolate(mask_logits, input_size, mode="trilinear").permute(0, 1, 4, 2, 3).contiguous())
318
+ else:
319
+ mask_logits_per_layer.append(mask_logits.permute(0, 1, 4, 2, 3).contiguous())
320
+ else:
321
+ mask_logits_per_layer.append(mask_logits)
322
+
323
+ attn_mask = torch.ones(
324
+ x.shape[0],
325
+ x.shape[1],
326
+ x.shape[1],
327
+ dtype=torch.bool,
328
+ device=x.device,
329
+ )
330
+ interpolated = F.interpolate(
331
+ mask_logits,
332
+ self.backbone.patch_embed.grid_size,
333
+ mode="trilinear",
334
+ )
335
+ interpolated = interpolated.view(
336
+ interpolated.size(0), interpolated.size(1), -1
337
+ )
338
+ attn_mask[
339
+ :,
340
+ : self.num_q,
341
+ self.num_q + self.backbone.num_prefix_tokens :,
342
+ ] = (
343
+ interpolated > 0
344
+ )
345
+ attn_mask = self._disable_attn_mask(
346
+ attn_mask,
347
+ self.attn_mask_probs[
348
+ i - len(self.backbone.blocks) + self.num_blocks
349
+ ],
350
+ )
351
+ x = x + block.drop_path1(
352
+ block.ls1(self._attn(block.attn, block.norm1(x), attn_mask, rope=rope))
353
+ )
354
+ x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))
355
+
356
+ mask_logits = self._predict(self.backbone.norm(x))
357
+ if self.for_nnunet:
358
+ input_size = tuple(
359
+ int(self.backbone.patch_embed.patch_size[dim] * self.backbone.patch_embed.grid_size[dim] / 2**0)
360
+ for dim in range(len(self.backbone.patch_embed.patch_size))
361
+ )
362
+ mask_logits_per_layer.append(F.interpolate(mask_logits, input_size, mode="trilinear").permute(0, 1, 4, 2, 3).contiguous())
363
+ else:
364
+ mask_logits_per_layer.append(mask_logits)
365
+
366
+ if self.for_nnunet:
367
+ # return in reversed order for deep supervision
368
+ mask_logits_per_layer = mask_logits_per_layer[::-1]
369
+ return mask_logits_per_layer if not self.return_only_final_layer else mask_logits_per_layer[0]
370
+
371
+
372
+ if __name__ == "__main__":
373
+ from spectre.models import vit_large_patch16_128
374
+
375
+ model = SEoMT(
376
+ backbone=vit_large_patch16_128(pos_embed='rope',
377
+ rope_kwargs={
378
+ "base": 1000.0, # works for most 3D models
379
+ },),
380
+ num_classes=4,
381
+ num_blocks=4,
382
+ masked_attn_enabled=True,
383
+ return_only_final_layer=True,
384
+ for_nnunet=True,
385
+ upscale_output=True,
386
+ decoder=False,
387
+ )
388
+ # print number of parameters
389
+ print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
390
+
391
+ x = torch.randn(2, 1, 64, 128, 128)
392
+ out = model(x)
393
+ # return_only_final_layer=True, so `out` is a single tensor rather than a list
394
+ print(out.shape)
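
Review note on the upscaling path in `SEoMT.__init__`: the per-stage scale factors follow directly from the patch size, so they can be checked without building a model. The sketch below restates that logic in plain Python under stated assumptions: the helper `per_stage_scale_factors` is a hypothetical name, and the patch size `(16, 16, 8)` is an assumed anisotropic example, not a value confirmed by this commit.

```python
import math
from typing import List, Sequence, Tuple


def compute_upscale_stages(patch_size: Sequence[int], min_size: int = 4) -> List[int]:
    """Number of 2x upscaling steps per axis, as in seomt.py."""
    return [max(0, int(math.log2(s)) - int(math.log2(min_size))) for s in patch_size]


def per_stage_scale_factors(
    patch_size: Sequence[int], min_size: int = 4
) -> List[Tuple[int, ...]]:
    """Scale factors each ScaleBlock would receive (hypothetical helper)."""
    stages = compute_upscale_stages(patch_size, min_size)
    return [
        # An axis is doubled only while it still has upscaling steps left
        tuple(2 if stage_idx < n else 1 for n in stages)
        for stage_idx in range(max(stages))
    ]


patch_size = (16, 16, 8)  # assumed example patch size
print(compute_upscale_stages(patch_size))   # [2, 2, 1]
print(per_stage_scale_factors(patch_size))  # [(2, 2, 2), (2, 2, 1)]
```

With power-of-two patch sizes no smaller than `min_size`, every axis ends at one mask cell per `min_size` voxels: here 16 / 4 = 4 along the first two axes and 8 / 2 = 4 along the third.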
spectre/models/upsample_anything.py ADDED
@@ -0,0 +1,319 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.optim.lr_scheduler import LambdaLR
6
+
7
+
8
+ def UPA(hr_image, lr_volume, device="cuda", use_amp=True):
9
+ """
10
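+ hr_image: high-resolution guidance volume, numpy array or torch tensor of shape [C,Hh,Wh,Dh]
11
+ lr_volume: low-resolution volume to upsample, torch tensor of shape [1,C,Hl,Wl,Dl]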
12
+ """
13
+
14
+ hr = torch.as_tensor(hr_image).unsqueeze(0).float().to(device)
15
+
16
+ _, _, Hh, Wh, Dh = hr.shape
17
+ _, _, Hl, Wl, Dl = lr_volume.shape
18
+ scale = Hh // Hl
19
+ assert Wh // Wl == scale and Dh // Dl == scale, "Inconsistent scale factors"
20
+
21
+ lr_volume = lr_volume.to(device).float()
22
+ lr = F.interpolate(hr, scale_factor=1/scale, mode="trilinear", align_corners=False)
23
+
24
+ model = LearnablePixelwiseAnisoJBU3D(
25
+ Hl, Wl, Dl, scale=scale
26
+ ).to(device)
27
+
28
+ model.train()
29
+ opt = torch.optim.Adam(model.parameters(), lr=1e-1)
30
+ max_steps = 350
31
+ gamma = (1e-9 / 1e-1) ** (1.0 / max_steps)
32
+ scheduler = LambdaLR(opt, lr_lambda=lambda step: gamma ** step)
33
+ scaler = torch.amp.GradScaler(device=device, enabled=use_amp)
34
+
35
+ for step in range(max_steps):
36
+ opt.zero_grad(set_to_none=True)
37
+ with torch.amp.autocast(device_type=device, enabled=use_amp):
38
+ pred = model(lr, hr)
39
+ loss = F.l1_loss(pred, hr)
40
+
41
+ scaler.scale(loss).backward()
42
+ scaler.step(opt)
43
+ scaler.update()
44
+ scheduler.step()
45
+
46
+ if step % 50 == 0:
47
+ print(f"step {step}: loss={loss.item():.5f}")
48
+
49
+ model.eval()
50
+ with torch.inference_mode(), \
51
+ torch.amp.autocast(device_type=device, enabled=use_amp, dtype=torch.float16):
52
+ out = model(lr_volume, hr)
53
+
54
+ return out
55
+
56
+
57
+ @torch.no_grad()
58
+ def _build_offsets_3d(R_max: int, device):
59
+ offs = torch.arange(-R_max, R_max + 1, device=device)
60
+ dX, dY, dZ = torch.meshgrid(offs, offs, offs, indexing="ij")
61
+ return (
62
+ dX.reshape(-1),
63
+ dY.reshape(-1),
64
+ dZ.reshape(-1),
65
+ ) # [K]
66
+
67
+
68
+ def gather_lr_scalar_3d(map_lr, Ui, Vi, Wi):
69
+ """
70
+ map_lr: [1,1,Hl,Wl,Dl] or [Hl,Wl,Dl]
71
+ Ui,Vi,Wi: [Bn,Hh,Wh,Dh]
72
+ """
73
+ Hl, Wl, Dl = map_lr.shape[-3:]
74
+ flat = Hl * Wl * Dl
75
+ idx = (Ui * Wl * Dl + Vi * Dl + Wi).reshape(-1)
76
+ t = map_lr.view(flat)
77
+ vals = t.index_select(0, idx)
78
+ return vals.view(Ui.shape)
79
+
80
+
81
+ def gs_jbu_aniso_noparent_3d(
82
+ feat_lr, # [1,C,Hl,Wl,Dl]
83
+ guide_hr, # [1,G,Hh,Wh,Dh]
84
+ scale,
85
+ sigma_x_map,
86
+ sigma_y_map,
87
+ sigma_z_map,
88
+ sigma_r_map,
89
+ R_max=3,
90
+ alpha_dyn=2.0,
91
+ C_chunk=64,
92
+ Nn_chunk=125,
93
+ ):
94
+ _, C, Hl, Wl, Dl = feat_lr.shape
95
+ _, _, Hh, Wh, Dh = guide_hr.shape
96
+ device = feat_lr.device
97
+ dtype_feat = feat_lr.dtype
98
+
99
+ # HR grid
100
+ x = torch.arange(Hh, device=device, dtype=torch.float32)
101
+ y = torch.arange(Wh, device=device, dtype=torch.float32)
102
+ z = torch.arange(Dh, device=device, dtype=torch.float32)
103
+ X, Y, Z = torch.meshgrid(x, y, z, indexing="ij")
104
+
105
+ u = (X + 0.5) / scale - 0.5
106
+ v = (Y + 0.5) / scale - 0.5
107
+ w = (Z + 0.5) / scale - 0.5
108
+
109
+ uc = torch.round(u).clamp(0, Hl - 1).long()
110
+ vc = torch.round(v).clamp(0, Wl - 1).long()
111
+ wc = torch.round(w).clamp(0, Dl - 1).long()
112
+
113
+ # Dynamic radius
114
+ sigma_eff = torch.maximum(
115
+ sigma_x_map,
116
+ torch.maximum(sigma_y_map, sigma_z_map),
117
+ )
118
+ sigma_eff_hr = F.interpolate(
119
+ sigma_eff, (Hh, Wh, Dh), mode="trilinear", align_corners=False
120
+ )
121
+ # sigma_eff_hr = sigma_eff_hr.squeeze(0).squeeze(0)
122
+ R_map = torch.ceil(alpha_dyn * sigma_eff_hr).clamp(1, R_max).long()
123
+
124
+ dX_all, dY_all, dZ_all = _build_offsets_3d(R_max, device)
125
+
126
+ num = torch.zeros(C, Hh, Wh, Dh, device=device, dtype=torch.float32)
127
+ den = torch.zeros(Hh, Wh, Dh, device=device, dtype=torch.float32)
128
+ m = torch.full((Hh, Wh, Dh), -1e9, device=device, dtype=torch.float32)
129
+
130
+ feat_flat = feat_lr[0].permute(1, 2, 3, 0).reshape(-1, C)
131
+ guide_lr = F.interpolate(
132
+ guide_hr, (Hl, Wl, Dl), mode="trilinear", align_corners=False
133
+ )
134
+
135
+ for n0 in range(0, len(dX_all), Nn_chunk):
136
+ dX = dX_all[n0:n0+Nn_chunk][:, None, None, None]
137
+ dY = dY_all[n0:n0+Nn_chunk][:, None, None, None]
138
+ dZ = dZ_all[n0:n0+Nn_chunk][:, None, None, None]
139
+
140
+ Ui = torch.clamp(uc.unsqueeze(0) + dX, 0, Hl - 1)
141
+ Vi = torch.clamp(vc.unsqueeze(0) + dY, 0, Wl - 1)
142
+ Wi = torch.clamp(wc.unsqueeze(0) + dZ, 0, Dl - 1)
143
+
144
+ # mask = (dX**2 + dY**2 + dZ**2 <= R_map[None, ...] ** 2)
145
+ mask = (dX**2 + dY**2 + dZ**2 <= R_map**2).squeeze(0).squeeze(0)
146
+
147
+ cx = (Ui.float() + 0.5) * scale - 0.5
148
+ cy = (Vi.float() + 0.5) * scale - 0.5
149
+ cz = (Wi.float() + 0.5) * scale - 0.5
150
+
151
+ dx = X.unsqueeze(0) - cx
152
+ dy = Y.unsqueeze(0) - cy
153
+ dz = Z.unsqueeze(0) - cz
154
+
155
+ sx = gather_lr_scalar_3d(sigma_x_map, Ui, Vi, Wi).clamp_min(1e-6)
156
+ sy = gather_lr_scalar_3d(sigma_y_map, Ui, Vi, Wi).clamp_min(1e-6)
157
+ sz = gather_lr_scalar_3d(sigma_z_map, Ui, Vi, Wi).clamp_min(1e-6)
158
+ sr = gather_lr_scalar_3d(sigma_r_map, Ui, Vi, Wi).clamp_min(1e-6)
159
+
160
+ log_ws = (
161
+ -(dx**2)/(2*sx**2)
162
+ -(dy**2)/(2*sy**2)
163
+ -(dz**2)/(2*sz**2)
164
+ )
165
+
166
+ diff2 = 0.0
167
+ for g in range(guide_hr.shape[1]):
168
+ g0 = gather_lr_scalar_3d(guide_lr[0, g], Ui, Vi, Wi)
169
+ diff2 += (guide_hr[0, g] - g0) ** 2
170
+
171
+ log_wr = -diff2 / (2 * sr**2 + 1e-8)
172
+ log_w = torch.where(mask, log_ws + log_wr, -1e9)
173
+
174
+ m_chunk = log_w.max(dim=0).values
175
+ m_new = torch.maximum(m, m_chunk)
176
+
177
+ scale_old = torch.exp(m - m_new)
178
+ num *= scale_old
179
+ den *= scale_old
180
+
181
+ w = torch.exp(log_w - m_new)
182
+ den += w.sum(0)
183
+
184
+ idx_flat = (Ui * Wl * Dl + Vi * Dl + Wi).reshape(-1)
185
+
186
+ for c0 in range(0, C, C_chunk):
187
+ c1 = min(c0 + C_chunk, C)
188
+ f = feat_flat[:, c0:c1].index_select(0, idx_flat)  # slice channels first so only C_chunk channels are gathered
189
+ f = f.view(w.shape + (c1 - c0,))
190
+ num[c0:c1] += (f * w[..., None]).sum(0).permute(3, 0, 1, 2)
191
+
192
+ m = m_new
193
+
194
+ out = (num / den.clamp_min(1e-8)).unsqueeze(0)
195
+ return out.to(dtype_feat)
196
+
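The neighbour loop in `gs_jbu_aniso_noparent_3d` accumulates a softmax-weighted average chunk by chunk, rescaling the running sums whenever a larger log-weight appears. The scalar sketch below illustrates that running-maximum trick in plain Python; the function names are hypothetical and the code only mirrors the accumulation pattern, not the full filter.

```python
import math


def direct_weighted_mean(log_w, values):
    """Softmax-weighted mean computed in one pass over all neighbours."""
    m = max(log_w)
    weights = [math.exp(l - m) for l in log_w]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def streaming_weighted_mean(log_w, values, chunk):
    """Same quantity, accumulated chunk by chunk with a running maximum."""
    m, num, den = -1e9, 0.0, 0.0  # -1e9 mirrors the initial value of `m` above
    for start in range(0, len(log_w), chunk):
        lw = log_w[start:start + chunk]
        vs = values[start:start + chunk]
        m_new = max(m, max(lw))
        rescale = math.exp(m - m_new)  # re-express the old sums relative to m_new
        num *= rescale
        den *= rescale
        for l, v in zip(lw, vs):
            w = math.exp(l - m_new)
            num += w * v
            den += w
        m = m_new
    return num / den


log_w = [-1000.0, -1001.0, -999.0, -1002.0, -998.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(direct_weighted_mean(log_w, values), streaming_weighted_mean(log_w, values, chunk=2))
```

Exponentiating the raw log-weights here would underflow to zero; subtracting the running maximum keeps every term representable while leaving the ratio `num / den` unchanged.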
197
+
198
+ class LearnablePixelwiseAnisoJBU3D(nn.Module):
199
+ def __init__(
200
+ self,
201
+ Hl,
202
+ Wl,
203
+ Dl,
204
+ scale,
205
+ init_sigma=1.5,
206
+ init_sigma_r=0.1,
207
+ R_max=3,
208
+ alpha_dyn=2.0,
209
+ ):
210
+ super().__init__()
211
+ self.scale = scale
212
+ self.R_max = R_max
213
+ self.alpha_dyn = alpha_dyn
214
+
215
+ self.sx_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
216
+ self.sy_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
217
+ self.sz_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
218
+ self.sr_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma_r)))
219
+
220
+ def forward(self, feat_lr, guide_hr):
221
+ return gs_jbu_aniso_noparent_3d(
222
+ feat_lr,
223
+ guide_hr,
224
+ self.scale,
225
+ torch.exp(self.sx_raw),
226
+ torch.exp(self.sy_raw),
227
+ torch.exp(self.sz_raw),
228
+ torch.exp(self.sr_raw),
229
+ R_max=self.R_max,
230
+ alpha_dyn=self.alpha_dyn,
231
+ )
232
+
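`LearnablePixelwiseAnisoJBU3D` stores each sigma as a logarithm and exponentiates it in `forward`, so the bandwidths stay positive whatever the optimiser does to the raw parameters. A scalar sketch of that parameterisation, together with the dynamic-radius rule `ceil(alpha_dyn * sigma)` clamped to `[1, R_max]`, is given below; the update of `-3.0` is an arbitrary illustrative value.

```python
import math

R_MAX, ALPHA_DYN = 3, 2.0


def radius(sigma: float) -> int:
    """Neighbourhood radius: ceil(alpha_dyn * sigma) clamped to [1, R_MAX]."""
    return min(max(math.ceil(ALPHA_DYN * sigma), 1), R_MAX)


init_sigma = 1.5
raw = math.log(init_sigma)   # value stored in the parameter
sigma = math.exp(raw)        # value actually used by the filter

# A hypothetical optimiser update can push the raw value below zero ...
raw_updated = raw - 3.0
# ... but the exponential still yields a valid, strictly positive sigma.
sigma_updated = math.exp(raw_updated)

print(sigma, radius(sigma))
print(sigma_updated, radius(sigma_updated))
```

With the default `init_sigma=1.5` the radius starts at the maximum of 3, and it shrinks towards 1 as the learned sigma decreases.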
233
+
234
+ if __name__ == "__main__":
235
+ import argparse
236
+
237
+ import numpy as np
238
+ import nibabel as nib
239
+ import monai.transforms as transforms
240
+
241
+ parser = argparse.ArgumentParser()
242
+ parser.add_argument("--image_path", type=str, required=True)
243
+ parser.add_argument("--mask_path", type=str, required=True)
244
+ parser.add_argument("--device", type=str, default="cuda")
245
+ parser.add_argument("--use_amp", action="store_true")
246
+ args = parser.parse_args()
247
+
248
+ transform = transforms.Compose([
249
+ transforms.LoadImaged(keys=("image", "mask")),
250
+ transforms.EnsureChannelFirstd(keys=("image", "mask"), channel_dim="no_channel"),
251
+ transforms.ScaleIntensityRanged(
252
+ keys=("image",),
253
+ a_min=-150,
254
+ a_max=250,
255
+ b_min=0.0,
256
+ b_max=1.0,
257
+ clip=True,
258
+ ),
259
+ transforms.Orientationd(keys=("image", "mask"), axcodes="RAS"),
260
+ transforms.RandWeightedCropd(
261
+ keys=("image", "mask"),
262
+ w_key="mask",
263
+ spatial_size=(128, 128, 64),
264
+ num_samples=1,
265
+ ),
266
+ transforms.CopyItemsd(keys=("mask"), times=1, names=("mask_low_res")),
267
+ transforms.Resized(keys=("mask_low_res"), spatial_size=(16, 16, 8), mode="nearest")  # align_corners is only valid for (tri)linear-style modes
268
+ ])
269
+ sample = transform({
270
+ "image": args.image_path,
271
+ "mask": args.mask_path,
272
+ })[0]
273
+
274
+ nib.save(
275
+ nib.Nifti1Image(
276
+ (F.interpolate(sample["mask_low_res"].unsqueeze(0), size=(128, 128, 64), mode="nearest").squeeze(0).squeeze(0).cpu().numpy().astype(np.uint8)),
277
+ affine=np.eye(4),
278
+ ),
279
+ "mask_low_res_upscaled.nii.gz",
280
+ )
281
+
282
+ sample["mask_low_res"] = F.one_hot(
283
+ sample["mask_low_res"].long().squeeze(0), num_classes=4,
284
+ ).permute(3, 0, 1, 2).unsqueeze(0).float()
285
+
286
+ print(sample["mask_low_res"].shape)
287
+
288
+ mask_out = UPA(
289
+ sample["image"],
290
+ sample["mask_low_res"],
291
+ device=args.device,
292
+ use_amp=args.use_amp,
293
+ )
294
+
295
+ mask_out = mask_out.argmax(dim=1, keepdim=True)
296
+
297
+ nib.save(
298
+ nib.Nifti1Image(
299
+ (sample["image"] * 255).squeeze(0).cpu().numpy().astype(np.uint8),
300
+ affine=np.eye(4),
301
+ ),
302
+ "image.nii.gz",
303
+ )
304
+ nib.save(
305
+ nib.Nifti1Image(
306
+ sample["mask"].squeeze(0).cpu().numpy().astype(np.uint8),
307
+ affine=np.eye(4),
308
+ ),
309
+ "mask.nii.gz",
310
+ )
311
+ torch.save(mask_out.squeeze(0).squeeze(0).cpu(), "upsampled_mask.pt")
312
+ nib.save(
313
+ nib.Nifti1Image(
314
+ mask_out.squeeze(0).squeeze(0).cpu().numpy().astype(np.uint8),
315
+ affine=np.eye(4),
316
+ ),
317
+ "upsampled_mask.nii.gz",
318
+ )
319
+
spectre/models/vision_transformer.py ADDED
@@ -0,0 +1,835 @@
1
+ import os
2
+ from functools import partial
3
+ from urllib.parse import urlparse
4
+ from typing import (
5
+ Tuple, Union, Callable, Literal,
6
+ Optional, Type, Set, List, Dict, Any,
7
+ )
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.layers import PatchDropout, AttentionPoolLatent
12
+ from timm.models.vision_transformer import LayerScale, DropPath, Mlp
13
+ from huggingface_hub import hf_hub_download, load_state_dict_from_file
14
+
15
+ from spectre.models.layers import (
16
+ PatchEmbed,
17
+ Attention,
18
+ RotaryPositionEmbedding,
19
+ )
20
+ from spectre.utils import (
21
+ resample_abs_pos_embed,
22
+ feature_take_indices,
23
+ global_pool_nlc,
24
+ )
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ attn_mode: str = 'mha',
33
+ q_proj_dim: Optional[int] = None,
34
+ kv_proj_dim: Optional[int] = None,
35
+ mlp_ratio: float = 4.,
36
+ qkv_bias: bool = False,
37
+ qk_norm: bool = False,
38
+ proj_bias: bool = True,
39
+ proj_drop: float = 0.,
40
+ attn_drop: float = 0.,
41
+ init_values: Optional[float] = None,
42
+ drop_path: float = 0.,
43
+ act_layer: Type[nn.Module] = nn.GELU,
44
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
45
+ mlp_layer: Type[nn.Module] = Mlp,
46
+ ) -> None:
47
+ super().__init__()
48
+ self.norm1 = norm_layer(dim)
49
+ self.attn = Attention(
50
+ dim,
51
+ num_heads=num_heads,
52
+ mode=attn_mode,
53
+ q_proj_dim=q_proj_dim,
54
+ kv_proj_dim=kv_proj_dim,
55
+ qkv_bias=qkv_bias,
56
+ qk_norm=qk_norm,
57
+ proj_bias=proj_bias,
58
+ attn_drop=attn_drop,
59
+ proj_drop=proj_drop,
60
+ norm_layer=norm_layer,
61
+ )
62
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
63
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
64
+
65
+ self.norm2 = norm_layer(dim)
66
+ self.mlp = mlp_layer(
67
+ in_features=dim,
68
+ hidden_features=int(dim * mlp_ratio),
69
+ act_layer=act_layer,
70
+ bias=proj_bias,
71
+ drop=proj_drop,
72
+ )
73
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
74
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
75
+
76
+ def forward(
77
+ self,
78
+ x: torch.Tensor,
79
+ rope = None
80
+ ) -> torch.Tensor:
81
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), rope=rope)))
82
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
83
+ return x
84
+
85
+
86
+ class VisionTransformer(nn.Module):
87
+ """ Vision Transformer with 3D Patch Embedding
88
+ """
89
+ def __init__(
90
+ self,
91
+ img_size: Union[int, Tuple[int, int, int]] = (128, 128, 64),
92
+ patch_size: Union[int, Tuple[int, int, int]] = (16, 16, 8),
93
+ in_chans: int = 1,
94
+ num_classes: int = 1000,
95
+ global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
96
+ embed_dim: int = 768,
97
+ depth: int = 12,
98
+ num_heads: int = 12,
99
+ attn_mode: str = 'mha',
100
+ q_proj_dim: Optional[int] = None,
101
+ kv_proj_dim: Optional[int] = None,
102
+ mlp_ratio: float = 4.,
103
+ qkv_bias: bool = True,
104
+ qk_norm: bool = False,
105
+ proj_bias: bool = True,
106
+ init_values: Optional[float] = None,
107
+ class_token: bool = True,
108
+ pos_embed: str = 'learn',
109
+ no_embed_class: bool = False,
110
+ rope_kwargs: Optional[dict] = None,
111
+ reg_tokens: int = 0,
112
+ pre_norm: bool = False,
113
+ final_norm: bool = True,
114
+ fc_norm: Optional[bool] = None,
115
+ dynamic_img_size: bool = False,
116
+ dynamic_img_pad: bool = False,
117
+ drop_rate: float = 0.,
118
+ pos_drop_rate: float = 0.,
119
+ patch_drop_rate: float = 0.,
120
+ proj_drop_rate: float = 0.,
121
+ attn_drop_rate: float = 0.,
122
+ drop_path_rate: float = 0.,
123
+ embed_layer: Callable = PatchEmbed,
124
+ embed_norm_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
125
+ norm_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
126
+ act_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
127
+ block_fn: Type[nn.Module] = Block,
128
+ mlp_layer: Type[nn.Module] = Mlp,
129
+ ) -> None:
130
+ """
131
+ Args:
132
+ img_size: Input image size.
133
+ patch_size: Patch size.
134
+ in_chans: Number of image input channels.
135
+ num_classes: Number of classes for classification head.
136
+ global_pool: Type of global pooling for final sequence (default: 'token').
137
+ embed_dim: Transformer embedding dimension.
138
+ depth: Depth of transformer.
139
+ num_heads: Number of attention heads.
140
+ attn_mode: Attention mode ('mha', 'mqa', 'mla').
141
+ q_proj_dim: Query projection dimension for 'mla' mode.
142
+ kv_proj_dim: Key, value projection dimension for 'mla' mode.
143
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
144
+ qkv_bias: Enable bias for qkv projections if True.
145
+ init_values: Layer-scale init values (layer-scale enabled if not None).
146
+ class_token: Use class token.
147
+ pos_embed: Type of position embedding to use (default: 'learn').
148
+ no_embed_class: Don't include position embeddings for class (or reg) tokens for learnable pos_embed.
149
+ rope_kwargs: Additional arguments for rotary position embedding.
150
+ reg_tokens: Number of register tokens.
151
+ pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
152
+ final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
153
+ fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
154
+ drop_rate: Head dropout rate.
155
+ pos_drop_rate: Position embedding dropout rate.
156
+ attn_drop_rate: Attention dropout rate.
157
+ drop_path_rate: Stochastic depth rate.
158
+ patch_drop_rate: Patch dropout rate.
159
+ proj_drop_rate: Dropout rate for the attention output projection and the MLP.
160
+ embed_layer: Patch embedding layer.
161
+ embed_norm_layer: Normalization layer to use / override in patch embed module.
162
+ norm_layer: Normalization layer.
163
+ act_layer: MLP activation layer.
164
+ block_fn: Transformer block layer.
165
+ """
166
+ super().__init__()
167
+ assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
168
+ assert class_token or global_pool != 'token'
169
+ assert pos_embed in ('', 'none', 'learn', 'rope')
170
+ assert attn_mode in ('mha', 'mqa', 'mla')
171
+ rope_kwargs = {} if rope_kwargs is None else dict(rope_kwargs)
172
+ rope_kwargs.setdefault("dtype", torch.float32) # robust with mixed-precision
173
+ use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
174
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
175
+ embed_norm_layer = embed_norm_layer
176
+ act_layer = act_layer or nn.GELU
177
+
178
+ self.num_classes = num_classes
179
+ self.global_pool = global_pool
180
+ self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
181
+ self.num_prefix_tokens = 1 if class_token else 0
182
+ self.num_prefix_tokens += reg_tokens
183
+ self.num_reg_tokens = reg_tokens
184
+ self.has_class_token = class_token
185
+ self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
186
+ self.dynamic_img_size = dynamic_img_size
187
+
188
+ embed_args = {}
189
+ if self.dynamic_img_size:
190
+ # flatten deferred until after pos embed
191
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWDC"))
192
+ elif pos_embed == 'rope':
193
+ embed_args['output_fmt'] = "NHWDC"
194
+ if embed_norm_layer is not None:
195
+ embed_args['norm_layer'] = embed_norm_layer
196
+ self.patch_embed = embed_layer(
197
+ img_size=img_size,
198
+ patch_size=patch_size,
199
+ in_chans=in_chans,
200
+ embed_dim=embed_dim,
201
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
202
+ dynamic_img_pad=dynamic_img_pad,
203
+ **embed_args,
204
+ )
205
+ num_patches = self.patch_embed.num_patches
206
+ reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
207
+
208
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
209
+ self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
210
+ embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
211
+ self.pos_embed, self.rope, self.requires_per_sample_rope = None, None, False
212
+ if pos_embed == 'learn':
213
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
214
+ if pos_embed == 'rope':
215
+ self.rope = RotaryPositionEmbedding(
216
+ embed_dim=embed_dim,
217
+ num_heads=num_heads,
218
+ **rope_kwargs,
219
+ )
220
+ self.requires_per_sample_rope = any([
221
+ self.rope.shift_coords is not None,
222
+ self.rope.jitter_coords is not None,
223
+ self.rope.rescale_coords is not None,
224
+ ])
225
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
226
+ if patch_drop_rate > 0:
227
+ self.patch_drop = PatchDropout(
228
+ patch_drop_rate,
229
+ num_prefix_tokens=self.num_prefix_tokens,
230
+ )
231
+ else:
232
+ self.patch_drop = nn.Identity()
233
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
234
+
235
+ dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)] # stochastic depth decay rule
236
+ self.blocks = nn.Sequential(*[
237
+ block_fn(
238
+ dim=embed_dim,
239
+ num_heads=num_heads,
240
+ attn_mode=attn_mode,
241
+ q_proj_dim=q_proj_dim,
242
+ kv_proj_dim=kv_proj_dim,
243
+ mlp_ratio=mlp_ratio,
244
+ qkv_bias=qkv_bias,
245
+ qk_norm=qk_norm,
246
+ proj_bias=proj_bias,
247
+ init_values=init_values,
248
+ proj_drop=proj_drop_rate,
249
+ attn_drop=attn_drop_rate,
250
+ drop_path=dpr[i],
251
+ norm_layer=norm_layer,
252
+ act_layer=act_layer,
253
+ mlp_layer=mlp_layer,
254
+ )
255
+ for i in range(depth)])
256
+ self.feature_info = [
257
+ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
258
+ self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
259
+
260
+ # Classifier Head
261
+ if global_pool == 'map':
262
+ self.attn_pool = AttentionPoolLatent(
263
+ self.embed_dim,
264
+ num_heads=num_heads,
265
+ mlp_ratio=mlp_ratio,
266
+ norm_layer=norm_layer,
267
+ act_layer=act_layer,
268
+ )
269
+ else:
270
+ self.attn_pool = None
271
+ self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
272
+ self.head_drop = nn.Dropout(drop_rate)
273
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
274
+
275
+ self.init_weights()
276
+
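The constructor assigns each block its own drop-path probability via the "stochastic depth decay rule": a linear ramp from 0 at the first block to `drop_path_rate` at the last. A standalone sketch of that rule is shown below; the function name `stochastic_depth_rates` is hypothetical.

```python
def stochastic_depth_rates(drop_path_rate: float, depth: int):
    """Linearly increasing drop-path rate per block; all zeros when depth == 1."""
    return [
        drop_path_rate * i / (depth - 1) if depth > 1 else 0.0
        for i in range(depth)
    ]


rates = stochastic_depth_rates(0.1, 12)
print([round(r, 4) for r in rates])
```

The `depth > 1` guard avoids a division by zero for single-block models, which then receive no stochastic depth at all.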
277
+ def init_weights(self) -> None:
278
+ if self.pos_embed is not None and not self.pos_embed.is_meta:
279
+ nn.init.trunc_normal_(self.pos_embed, std=.02)
280
+ if self.cls_token is not None and not self.cls_token.is_meta:
281
+ nn.init.normal_(self.cls_token, std=1e-6)
282
+ if self.reg_token is not None and not self.reg_token.is_meta:
283
+ nn.init.normal_(self.reg_token, std=1e-6)
284
+ self.apply(self._init_weights)
285
+
286
+ def _init_weights(self, m: nn.Module) -> None:
287
+ # this fn left here for compat with downstream users
288
+ if isinstance(m, nn.Linear):
289
+ if not m.weight.is_meta:
290
+ nn.init.trunc_normal_(m.weight, std=.02)
291
+ if m.bias is not None and not m.bias.is_meta:
292
+ nn.init.zeros_(m.bias)
293
+
294
+ @torch.jit.ignore
295
+ def no_weight_decay(self) -> Set:
296
+ return {'pos_embed', 'cls_token', 'reg_token'}
297
+
298
+ @torch.jit.ignore
299
+ def get_classifier(self) -> nn.Module:
300
+ return self.head
301
+
302
+ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
303
+ self.num_classes = num_classes
304
+ if global_pool is not None:
305
+ assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
306
+ if global_pool == 'map' and self.attn_pool is None:
307
+ assert False, "Cannot currently add attention pooling in reset_classifier()."
308
+ elif global_pool != 'map' and self.attn_pool is not None:
309
+ self.attn_pool = None # remove attention pooling
310
+ self.global_pool = global_pool
311
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
312
+
313
+ def set_input_size(
314
+ self,
315
+ img_size: Optional[Tuple[int, int, int]] = None,
316
+ patch_size: Optional[Tuple[int, int, int]] = None,
317
+ ):
318
+ """Method updates the input image resolution, patch size
319
+
320
+ Args:
321
+ img_size: New input resolution, if None current resolution is used
322
+ patch_size: New patch size, if None existing patch size is used
323
+ """
324
+ prev_grid_size = self.patch_embed.grid_size
325
+ self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
326
+ if self.pos_embed is not None:
327
+ num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
328
+ num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
329
+ if num_new_tokens != self.pos_embed.shape[1]:
330
+ self.pos_embed = nn.Parameter(resample_abs_pos_embed(
331
+ self.pos_embed,
332
+ new_size=self.patch_embed.grid_size,
333
+ old_size=prev_grid_size,
334
+ num_prefix_tokens=num_prefix_tokens,
335
+ verbose=True,
336
+ ))
337
+
338
+ def _pos_embed(self, x: torch.Tensor):
339
+ if self.pos_embed is None and self.rope is None:
340
+ x = x.view(x.shape[0], -1, x.shape[-1])
341
+ if self.reg_token is not None:
342
+ x = torch.cat([self.reg_token.expand(x.shape[0], -1, -1), x], dim=1)
343
+ if self.cls_token is not None:
344
+ x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
345
+ return x, None
346
+
347
+ if self.dynamic_img_size or self.rope is not None:
348
+ B, H, W, D, C = x.shape
349
+ x = x.view(B, -1, C)
350
+
351
+ pos_embed, rope = None, None
352
+ if self.pos_embed is not None:
353
+ if self.dynamic_img_size:
354
+ prev_grid_size = self.patch_embed.grid_size
355
+ pos_embed = resample_abs_pos_embed(
356
+ self.pos_embed,
357
+ new_size=(H, W, D),
358
+ old_size=prev_grid_size,
359
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
360
+ )
361
+ else:
362
+ pos_embed = self.pos_embed
363
+
364
+ if self.rope is not None:
365
+ if self.requires_per_sample_rope:
366
+ rope = [self.rope(H=H, W=W, D=D) for _ in range(B)]
367
+ else:
368
+ rope = self.rope(H=H, W=W, D=D)
369
+
370
+ to_cat = []
371
+ if self.cls_token is not None:
372
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
373
+ if self.reg_token is not None:
374
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
375
+
376
+ if self.no_embed_class:
377
+ # deit-3, updated JAX (big vision)
378
+ # position embedding does not overlap with class token, add then concat
379
+ if pos_embed is not None:
380
+ x = x + pos_embed
381
+ if to_cat:
382
+ x = torch.cat(to_cat + [x], dim=1)
383
+ else:
384
+ # original timm, JAX, and deit vit impl
385
+ # pos_embed has entry for class token, concat then add
386
+ if to_cat:
387
+ x = torch.cat(to_cat + [x], dim=1)
388
+ if pos_embed is not None:
389
+ x = x + pos_embed
390
+
391
+ return self.pos_drop(x), rope
392
+
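Whether `_pos_embed` adds the learned table before or after concatenating the prefix tokens determines how long that table must be. The sketch below reproduces the bookkeeping in plain Python. It assumes that the patch embedding yields one token per non-overlapping patch (the `PatchEmbed` layer itself is not shown here), and the helper names are hypothetical.

```python
def num_patches(img_size, patch_size):
    """Number of non-overlapping patches, assuming floor division per axis."""
    n = 1
    for s, p in zip(img_size, patch_size):
        n *= s // p
    return n


def pos_embed_len(img_size, patch_size, class_token=True, reg_tokens=0, no_embed_class=False):
    """Length of the learned position table for a given token layout."""
    prefix = (1 if class_token else 0) + reg_tokens
    patches = num_patches(img_size, patch_size)
    # With no_embed_class the table covers patch tokens only (add, then concat);
    # otherwise it also has entries for the prefix tokens (concat, then add).
    return patches if no_embed_class else patches + prefix


print(num_patches((128, 128, 64), (16, 16, 8)))
print(pos_embed_len((128, 128, 64), (16, 16, 8)))
print(pos_embed_len((128, 128, 64), (16, 16, 8), reg_tokens=4, no_embed_class=True))
```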
393
+ def forward_intermediates(
394
+ self,
395
+ x: torch.Tensor,
396
+ indices: Optional[Union[int, List[int]]] = None,
397
+ return_prefix_tokens: bool = False,
398
+ norm: bool = False,
399
+ stop_early: bool = False,
400
+ output_fmt: str = 'NCHWD',
401
+ intermediates_only: bool = False,
402
+ output_dict: bool = False,
403
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
404
+ """ Forward features that returns intermediates.
405
+
406
+ Args:
407
+ x: Input image tensor
408
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
409
+ return_prefix_tokens: Return both prefix and spatial intermediate tokens
410
+ norm: Apply norm layer to all intermediates
411
+ stop_early: Stop iterating over blocks when last desired intermediate hit
412
+ output_fmt: Shape of intermediate feature outputs
413
+ intermediates_only: Only return intermediate features
414
+ output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
415
+ Returns:
416
+ A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
417
+ 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
418
+
419
+ """
420
+ assert output_fmt in ('NCHWD', 'NLC'), 'Output format must be one of NCHWD or NLC.'
421
+ reshape = output_fmt == 'NCHWD'
422
+ intermediates = []
423
+ take_indices, max_index = feature_take_indices(len(self.blocks), indices)
424
+
425
+ # forward pass
426
+ B, _, height, width, depth = x.shape
427
+ x = self.patch_embed(x)
428
+ x, rope = self._pos_embed(x)
429
+ x = self.patch_drop(x)
430
+ x = self.norm_pre(x)
431
+
432
+ if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
433
+ blocks = self.blocks
434
+ else:
435
+ blocks = self.blocks[:max_index + 1]
436
+ for i, blk in enumerate(blocks):
437
+ x = blk(x, rope=rope)
438
+ if i in take_indices:
439
+ # normalize intermediates with final norm layer if enabled
440
+ intermediates.append(self.norm(x) if norm else x)
441
+
442
+ # process intermediates
443
+ if self.num_prefix_tokens:
444
+ # split prefix (e.g. class, register) and spatial feature tokens
445
+ prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
446
+ intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
447
+ else:
448
+ prefix_tokens = None
449
+
450
+ if reshape:
451
+ # reshape to NCHWD output format
452
+ H, W, D = self.patch_embed.dynamic_feat_size((height, width, depth))
453
+ intermediates = [y.reshape(B, H, W, D, -1).permute(0, 4, 1, 2, 3).contiguous() for y in intermediates]
454
+
455
+ if output_dict:
456
+ result_dict = {}
457
+ # Intermediates are always included
458
+ result_dict['image_intermediates'] = intermediates
459
+ if prefix_tokens is not None and return_prefix_tokens:
460
+ result_dict['image_intermediates_prefix'] = prefix_tokens
461
+
462
+ # Only include features if not intermediates_only
463
+ if not intermediates_only:
464
+ x_final = self.norm(x)
465
+ result_dict['image_features'] = x_final
466
+
467
+ return result_dict
468
+
469
+ # For non-dictionary output, maintain the original behavior
470
+ if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
471
+ # return_prefix_tokens is not supported in torchscript due to poor type handling
472
+ intermediates = list(zip(intermediates, prefix_tokens))
473
+
474
+ if intermediates_only:
475
+ return intermediates
476
+
477
+ x = self.norm(x)
478
+
479
+ return x, intermediates
480
+
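When `output_fmt='NCHWD'`, `forward_intermediates` reshapes the flat sequence of spatial tokens back onto the `(H, W, D)` patch grid. The index arithmetic implied by that row-major reshape can be sketched as follows. This assumes the patch embedding flattens the grid in `H, W, D` order, and `token_to_grid` is a hypothetical helper.

```python
def token_to_grid(l: int, grid):
    """Map a flat spatial token index to (h, w, d) under a row-major reshape."""
    _, W, D = grid
    h, rem = divmod(l, W * D)
    w, d = divmod(rem, D)
    return h, w, d


grid = (8, 8, 8)  # 128 x 128 x 64 input with 16 x 16 x 8 patches
for l in (0, 9, 511):
    print(l, token_to_grid(l, grid))
```

Prefix tokens (class and register) are split off before the reshape, so index 0 here refers to the first patch token, not the class token.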
481
+ def prune_intermediate_layers(
482
+ self,
483
+ indices: Union[int, List[int]] = 1,
484
+ prune_norm: bool = False,
485
+ prune_head: bool = True,
486
+ ):
487
+ """Prune layers not required for specified intermediates.
488
+
489
+ Args:
490
+ indices: Indices of intermediate layers to keep.
491
+ prune_norm: Whether to prune normalization layer.
492
+ prune_head: Whether to prune the classifier head.
493
+
494
+ Returns:
495
+ List of indices that were kept.
496
+ """
497
+ take_indices, max_index = feature_take_indices(len(self.blocks), indices)
498
+ self.blocks = self.blocks[:max_index + 1] # truncate blocks
499
+ if prune_norm:
500
+ self.norm = nn.Identity()
501
+ if prune_head:
502
+ self.fc_norm = nn.Identity()
503
+ self.reset_classifier(0, '')
504
+ return take_indices
505
+
506
+ def get_intermediate_layers(
507
+ self,
508
+ x: torch.Tensor,
509
+ n: Union[int, List[int], Tuple[int]] = 1,
510
+ reshape: bool = False,
511
+ return_prefix_tokens: bool = False,
512
+ norm: bool = False,
513
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
514
+ """Get intermediate layer outputs (DINO interface compatibility).
515
+
516
+ NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
517
+
518
+ Args:
519
+ x: Input tensor.
520
+ n: Number or indices of layers.
521
+ reshape: Reshape to NCHWD format.
522
+ return_prefix_tokens: Return prefix tokens.
523
+ norm: Apply normalization.
524
+
525
+ Returns:
526
+ List of intermediate features.
527
+ """
528
+ return self.forward_intermediates(
529
+ x, n,
530
+ return_prefix_tokens=return_prefix_tokens,
531
+ norm=norm,
532
+ output_fmt='NCHWD' if reshape else 'NLC',
533
+ intermediates_only=True,
534
+ )
535
+
536
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
537
+ """Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm)."""
538
+ x = self.patch_embed(x)
539
+ x, rope = self._pos_embed(x)
540
+ x = self.patch_drop(x)
541
+ x = self.norm_pre(x)
542
+
543
+ for blk in self.blocks:
544
+ x = blk(x, rope=rope)
545
+ x = self.norm(x)
546
+ return x
547
+
548
+ def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
549
+ """Apply pooling to feature tokens.
550
+
551
+ Args:
552
+ x: Feature tensor.
553
+ pool_type: Pooling type override.
554
+
555
+ Returns:
556
+ Pooled features.
557
+ """
558
+ if self.attn_pool is not None:
559
+ x = self.attn_pool(x)
560
+ return x
561
+ pool_type = self.global_pool if pool_type is None else pool_type
562
+ x = global_pool_nlc(
563
+ x,
564
+ pool_type=pool_type,
565
+ num_prefix_tokens=self.num_prefix_tokens,
566
+ )
567
+ return x
568
+
569
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
570
+ """Forward pass through classifier head.
571
+
572
+ Args:
573
+ x: Feature tensor.
574
+ pre_logits: Return features before final classifier.
575
+
576
+ Returns:
577
+ Output tensor.
578
+ """
579
+ x = self.pool(x)
580
+ x = self.fc_norm(x)
581
+ x = self.head_drop(x)
582
+ return x if pre_logits else self.head(x)
583
+
584
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
585
+ x = self.forward_features(x)
586
+ x = self.forward_head(x)
587
+ return x
588
+
589
+ @classmethod
590
+ def from_pretrained(
591
+ cls,
592
+ checkpoint_path_or_url: Union[str, os.PathLike],
593
+ verbose: bool = True,
594
+ **kwargs
595
+ ) -> 'VisionTransformer':
596
+ """Load pretrained model weights from a local path or a URL."""
597
+ model = cls(**kwargs)
598
+
599
+ def _is_url(path: str) -> bool:
600
+ try:
601
+ parsed = urlparse(str(path))
602
+ return parsed.scheme in ('http', 'https')
603
+ except Exception:
604
+ return False
605
+
606
+ def _is_hf_url(path: str) -> bool:
607
+ try:
608
+ parsed = urlparse(str(path))
609
+ return 'huggingface.co' in parsed.netloc
610
+ except Exception:
611
+ return False
612
+
613
+ if _is_hf_url(checkpoint_path_or_url):
614
+ if verbose:
615
+ print(f"Downloading pretrained weights from Hugging Face URL: {checkpoint_path_or_url}")
616
+ # Extract repo_id and filename from the URL
617
+ parsed = urlparse(checkpoint_path_or_url)
618
+ parts = parsed.path.strip('/').split('/')
619
+ repo_id = '/'.join(parts[:2]) # e.g., 'cclaess/SPECTRE'
620
+ filename = parts[-1] # e.g., 'spectre_backbone_vit_large_patch16_128.pt'
621
+
622
+ local_path = hf_hub_download(repo_id=repo_id, filename=filename)
623
+ state_dict = load_state_dict_from_file(local_path, map_location='cpu')
624
+ elif _is_url(checkpoint_path_or_url):
625
+ if verbose:
626
+ print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
627
+ state_dict = torch.hub.load_state_dict_from_url(
628
+ checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
629
+ else:
630
+ local_path = os.fspath(checkpoint_path_or_url)
631
+ if not os.path.exists(local_path):
632
+ raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
633
+ if verbose:
634
+ print(f"Loading checkpoint from local path: {local_path}")
635
+ state_dict = torch.load(local_path, map_location='cpu', weights_only=False)
636
+
637
+ msg = model.load_state_dict(state_dict, strict=False)
638
+ if verbose:
639
+ print(f"Loaded pretrained weights with msg: {msg}")
640
+ return model
641
+
642
+
643
+ def vit_tiny_patch16_128(
644
+ checkpoint_path_or_url: Optional[str] = None,
645
+ **kwargs
646
+ ) -> VisionTransformer:
647
+ """ViT-Tiny model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].
648
+ """
649
+ kwargs = dict(
650
+ img_size=(128, 128, 64),
651
+ patch_size=(16, 16, 8),
652
+ embed_dim=192,
653
+ depth=12,
654
+ num_heads=2,
655
+ mlp_ratio=4,
656
+ qkv_bias=True,
657
+ norm_layer=nn.LayerNorm,
658
+ **kwargs,
659
+ )
660
+ if checkpoint_path_or_url is not None:
661
+ return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
662
+ return VisionTransformer(**kwargs)
663
+
664
+
665
+ def vit_small_patch16_128(
666
+ checkpoint_path_or_url: Optional[str] = None,
667
+ **kwargs
668
+ ) -> VisionTransformer:
669
+ """ViT-Small model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].
670
+ """
671
+ kwargs = dict(
672
+ img_size=(128, 128, 64),
673
+ patch_size=(16, 16, 8),
674
+ embed_dim=384,
675
+ depth=12,
676
+ num_heads=4,
677
+ mlp_ratio=4,
678
+ qkv_bias=True,
679
+ norm_layer=nn.LayerNorm,
680
+ **kwargs,
681
+ )
682
+ if checkpoint_path_or_url is not None:
683
+ return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
684
+ return VisionTransformer(**kwargs)
685
+
686
+
687
+ def vit_base_patch16_128(
688
+ checkpoint_path_or_url: Optional[str] = None,
689
+ **kwargs
690
+ ) -> VisionTransformer:
691
+ """ViT-Base model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].
692
+ """
693
+ kwargs = dict(
694
+ img_size=(128, 128, 64),
695
+ patch_size=(16, 16, 8),
696
+ embed_dim=768,
697
+ depth=12,
698
+ num_heads=8,
699
+ mlp_ratio=4,
700
+ qkv_bias=True,
701
+ norm_layer=nn.LayerNorm,
702
+ **kwargs,
703
+ )
704
+ if checkpoint_path_or_url is not None:
705
+ return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
706
+ return VisionTransformer(**kwargs)
707
+
708
+ def vit_base_patch16_256(
709
+ pretrained_weights: Optional[str] = None,
710
+ **kwargs
711
+ ) -> VisionTransformer:
712
+ """ViT-Base model with 3D patch embedding, patch size [16, 16, 8] and input size [256, 256, 128].
713
+ """
714
+ kwargs = dict(
715
+ img_size=(256, 256, 128),
716
+ patch_size=(16, 16, 8),
717
+ embed_dim=768,
718
+ depth=12,
719
+ num_heads=8,
720
+ mlp_ratio=4,
721
+ qkv_bias=True,
722
+ norm_layer=nn.LayerNorm,
723
+ **kwargs,
724
+ )
725
+ if pretrained_weights is not None:
726
+ return VisionTransformer.from_pretrained(pretrained_weights, **kwargs)
727
+ return VisionTransformer(**kwargs)
728
+
729
+
730
+ def vit_base_patch32_128(
731
+ checkpoint_path_or_url: Optional[str] = None,
732
+ **kwargs
733
+ ) -> VisionTransformer:
734
+ """ViT-Base model with 3D patch embedding, patch size [32, 32, 16] and input size [128, 128, 64].
735
+ """
736
+ kwargs = dict(
737
+ img_size=(128, 128, 64),
738
+ patch_size=(32, 32, 16),
739
+ embed_dim=768,
740
+ depth=12,
741
+ num_heads=8,
742
+ mlp_ratio=4,
743
+ qkv_bias=True,
744
+ norm_layer=nn.LayerNorm,
745
+ **kwargs,
746
+ )
747
+ if checkpoint_path_or_url is not None:
748
+ return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
749
+ return VisionTransformer(**kwargs)
750
+
751
+
752
+ def vit_large_patch16_128(
753
+ checkpoint_path_or_url: Optional[str] = None,
754
+ **kwargs
755
+ ) -> VisionTransformer:
756
+ """ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].
757
+ """
758
+ kwargs = dict(
759
+ img_size=(128, 128, 64),
760
+ patch_size=(16, 16, 8),
761
+ embed_dim=1080,
762
+ depth=24,
763
+ num_heads=12,
764
+ mlp_ratio=4,
765
+ qkv_bias=True,
766
+ norm_layer=nn.LayerNorm,
767
+ **kwargs,
768
+ )
769
+ if checkpoint_path_or_url is not None:
770
+ return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
771
+ return VisionTransformer(**kwargs)
772
+
773
+ def vit_large_patch16_256(
774
+ pretrained_weights: Optional[str] = None,
775
+ **kwargs
776
+ ) -> VisionTransformer:
777
+ """ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [256, 256, 128].
778
+ """
779
+ kwargs = dict(
780
+ img_size=(256, 256, 128),
781
+ patch_size=(16, 16, 8),
782
+ embed_dim=1080,
783
+ depth=24,
784
+ num_heads=12,
785
+ mlp_ratio=4,
786
+ qkv_bias=True,
787
+ norm_layer=nn.LayerNorm,
788
+ **kwargs,
789
+ )
790
+ if pretrained_weights is not None:
791
+ return VisionTransformer.from_pretrained(pretrained_weights, **kwargs)
792
+ return VisionTransformer(**kwargs)
793
+
794
+ def vit_large_patch16_320(
795
+ pretrained_weights: Optional[str] = None,
796
+ **kwargs
797
+ ) -> VisionTransformer:
798
+ """ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [320, 320, 128].
799
+ """
800
+ kwargs = dict(
801
+ img_size=(320, 320, 128),
802
+ patch_size=(16, 16, 8),
803
+ embed_dim=1080,
804
+ depth=24,
805
+ num_heads=12,
806
+ mlp_ratio=4,
807
+ qkv_bias=True,
808
+ norm_layer=nn.LayerNorm,
809
+ **kwargs,
810
+ )
811
+ if pretrained_weights is not None:
812
+ return VisionTransformer.from_pretrained(pretrained_weights, **kwargs)
813
+ return VisionTransformer(**kwargs)
814
+
815
+
816
+ def vit_large_patch32_128(
817
+ checkpoint_path_or_url: Optional[str] = None,
818
+ **kwargs
819
+ ) -> VisionTransformer:
820
+ """ViT-Large model with 3D patch embedding, patch size [32, 32, 16] and input size [128, 128, 64].
821
+ """
822
+ kwargs = dict(
823
+ img_size=(128, 128, 64),
824
+ patch_size=(32, 32, 16),
825
+ embed_dim=1080,
826
+ depth=24,
827
+ num_heads=12,
828
+ mlp_ratio=4,
829
+ qkv_bias=True,
830
+ norm_layer=nn.LayerNorm,
831
+ **kwargs,
832
+ )
833
+ if checkpoint_path_or_url is not None:
834
+ return VisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
835
+ return VisionTransformer(**kwargs)
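Note on the factory functions above: the sequence length and per-head width implied by each configuration can be checked with a few lines of arithmetic. The sketch below is not part of the commit. It copies `img_size`, `patch_size`, `embed_dim`, and `num_heads` from four of the factories and assumes a non-overlapping patch embedding (the patch embedding layer itself is not shown in this part of the diff). The names `CONFIGS`, `token_grid`, and `summarize` are hypothetical helpers introduced only for illustration.

```python
from math import prod

# (img_size, patch_size, embed_dim, num_heads), copied from the factory functions.
CONFIGS = {
    "vit_tiny_patch16_128": ((128, 128, 64), (16, 16, 8), 192, 2),
    "vit_base_patch16_256": ((256, 256, 128), (16, 16, 8), 768, 8),
    "vit_base_patch32_128": ((128, 128, 64), (32, 32, 16), 768, 8),
    "vit_large_patch16_320": ((320, 320, 128), (16, 16, 8), 1080, 12),
}


def token_grid(img_size, patch_size):
    """Patches per axis, assuming non-overlapping patches that tile the volume exactly."""
    if any(i % p for i, p in zip(img_size, patch_size)):
        raise ValueError(f"{img_size} is not divisible by {patch_size}")
    return tuple(i // p for i, p in zip(img_size, patch_size))


def summarize(name):
    """Return the patch grid, patch count, and per-head width for one configuration."""
    img_size, patch_size, embed_dim, num_heads = CONFIGS[name]
    grid = token_grid(img_size, patch_size)
    return {
        "grid": grid,
        "num_patches": prod(grid),
        "head_dim": embed_dim // num_heads,
    }


for name in CONFIGS:
    print(name, summarize(name))
```

Under these assumptions the 128-input, patch-16 models see 512 patch tokens, while the 320-input large model sees 6400, which is worth keeping in mind when estimating attention memory.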
spectre/models/vision_transformer_features.py ADDED
@@ -0,0 +1,455 @@
1
+ import os
2
+ import math
3
+ from functools import partial
4
+ from urllib.parse import urlparse
5
+ from typing import Union, Callable, Literal, Optional, Type, Set, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from timm.models.vision_transformer import Mlp
10
+ from timm.layers import PatchDropout, AttentionPoolLatent
11
+ from huggingface_hub import hf_hub_download, load_state_dict_from_file
12
+
13
+ from spectre.utils import global_pool_nlc, to_3tuple, resample_abs_pos_embed
14
+ from spectre.models.vision_transformer import Block
15
+ from spectre.models.layers import RotaryPositionEmbedding
16
+
17
+
18
+ class FeatureVisionTransformer(nn.Module):
19
+ """ Vision Transformer that accepts flattened patches as input.
20
+ """
21
+ def __init__(
22
+ self,
23
+ grid_size: Optional[Union[int, Tuple[int, int, int]]] = None,
24
+ patch_dim: int = 768,
25
+ num_classes: int = 1000,
26
+ global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
27
+ embed_dim: int = 768,
28
+ depth: int = 12,
29
+ num_heads: int = 12,
30
+ attn_mode: str = 'mha',
31
+ q_proj_dim: Optional[int] = None,
32
+ kv_proj_dim: Optional[int] = None,
33
+ mlp_ratio: float = 4.,
34
+ qkv_bias: bool = True,
35
+ qk_norm: bool = False,
36
+ proj_bias: bool = True,
37
+ init_values: Optional[float] = None,
38
+ class_token: bool = True,
39
+ pos_embed: str = 'learn',
40
+ no_embed_class: bool = False,
41
+ rope_kwargs: Optional[dict] = None,
42
+ reg_tokens: int = 0,
43
+ pre_norm: bool = False,
44
+ final_norm: bool = True,
45
+ fc_norm: Optional[bool] = None,
46
+ dynamic_grid_size: bool = False,
47
+ drop_rate: float = 0.,
48
+ pos_drop_rate: float = 0.,
49
+ patch_drop_rate: float = 0.,
50
+ proj_drop_rate: float = 0.,
51
+ attn_drop_rate: float = 0.,
52
+ drop_path_rate: float = 0.,
53
+ norm_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
54
+ act_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
55
+ block_fn: Type[nn.Module] = Block,
56
+ mlp_layer: Type[nn.Module] = Mlp,
57
+ ) -> None:
58
+ """
59
+ Args:
60
+ grid_size: Patch grid size (int or 3-tuple); required when pos_embed is 'learn'.
61
+ patch_dim: Dimension of each flattened input patch.
62
+ num_classes: Number of classes for classification head.
63
+ global_pool: Type of global pooling for final sequence (default: 'token').
64
+ embed_dim: Transformer embedding dimension.
65
+ depth: Depth of transformer.
66
+ num_heads: Number of attention heads.
67
+ attn_mode: Attention mode ('mha', 'mqa', 'mla').
68
+ q_proj_dim: Query projection dimension for 'mla' mode.
69
+ kv_proj_dim: Key, value projection dimension for 'mla' mode.
70
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
71
+ qkv_bias: Enable bias for qkv projections if True.
72
+ init_values: Layer-scale init values (layer-scale enabled if not None).
73
+ class_token: Use class token.
74
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
75
+ reg_tokens: Number of register tokens.
76
+ pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
77
+ final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
78
+ fc_norm: Move final norm after pool (instead of before); if None, enabled when global_pool is 'avg', 'avgmax', or 'max'.
79
+ drop_rate: Head dropout rate.
80
+ pos_drop_rate: Position embedding dropout rate.
81
+ attn_drop_rate: Attention dropout rate.
82
+ drop_path_rate: Stochastic depth rate.
83
+ pos_embed: Position embedding type ('', 'none', 'learn', 'rope').
84
+ dynamic_grid_size: Resample learned position embeddings to the grid_size passed at forward time.
85
+ norm_layer: Normalization layer.
86
+ act_layer: MLP activation layer.
87
+ block_fn: Transformer block layer.
88
+ """
89
+ super().__init__()
90
+ assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
91
+ assert class_token or global_pool != 'token'
92
+ assert pos_embed in ('', 'none', 'learn', 'rope')
93
+ assert attn_mode in ('mha', 'mqa', 'mla')
94
+ assert grid_size is not None or pos_embed in ('', 'none', 'rope')
95
+ rope_kwargs = {} if rope_kwargs is None else dict(rope_kwargs)
96
+ rope_kwargs.setdefault("dtype", torch.float32) # robust with mixed-precision
97
+ use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
98
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
99
+ act_layer = act_layer or nn.GELU
100
+
101
+ self.grid_size = None if grid_size is None else to_3tuple(grid_size)
102
+ self.num_classes = num_classes
103
+ self.global_pool = global_pool
104
+ self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
105
+ self.num_prefix_tokens = 1 if class_token else 0
106
+ self.num_prefix_tokens += reg_tokens
107
+ self.num_reg_tokens = reg_tokens
108
+ self.has_class_token = class_token
109
+ self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
110
+ self.dynamic_grid_size = dynamic_grid_size
111
+
112
+ self.num_patches = None if self.grid_size is None else int(math.prod(self.grid_size))
113
+ self.patch_proj = nn.Linear(patch_dim, embed_dim, proj_bias)
114
+
115
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
116
+ self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
117
+ self.pos_embed, self.rope, self.requires_per_sample_rope = None, None, False
118
+ if pos_embed == 'learn':
119
+ embed_len = self.num_patches if no_embed_class else self.num_patches + self.num_prefix_tokens
120
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
121
+ if pos_embed == 'rope':
122
+ self.rope = RotaryPositionEmbedding(
123
+ embed_dim=embed_dim,
124
+ num_heads=num_heads,
125
+ **rope_kwargs,
126
+ )
127
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
128
+ if patch_drop_rate > 0:
129
+ self.patch_drop = PatchDropout(
130
+ patch_drop_rate,
131
+ num_prefix_tokens=self.num_prefix_tokens,
132
+ )
133
+ else:
134
+ self.patch_drop = nn.Identity()
135
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
136
+
137
+ dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)] # stochastic depth decay rule
138
+ self.blocks = nn.Sequential(*[
139
+ block_fn(
140
+ dim=embed_dim,
141
+ num_heads=num_heads,
142
+ attn_mode=attn_mode,
143
+ q_proj_dim=q_proj_dim,
144
+ kv_proj_dim=kv_proj_dim,
145
+ mlp_ratio=mlp_ratio,
146
+ qkv_bias=qkv_bias,
147
+ qk_norm=qk_norm,
148
+ proj_bias=proj_bias,
149
+ init_values=init_values,
150
+ proj_drop=proj_drop_rate,
151
+ attn_drop=attn_drop_rate,
152
+ drop_path=dpr[i],
153
+ norm_layer=norm_layer,
154
+ act_layer=act_layer,
155
+ mlp_layer=mlp_layer,
156
+ )
157
+ for i in range(depth)])
158
+ self.feature_info = [
159
+ dict(module=f'blocks.{i}', num_chs=embed_dim) for i in range(depth)]
160
+ self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
161
+
162
+ # Classifier Head
163
+ if global_pool == 'map':
164
+ self.attn_pool = AttentionPoolLatent(
165
+ self.embed_dim,
166
+ num_heads=num_heads,
167
+ mlp_ratio=mlp_ratio,
168
+ norm_layer=norm_layer,
169
+ act_layer=act_layer,
170
+ )
171
+ else:
172
+ self.attn_pool = None
173
+ self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
174
+ self.head_drop = nn.Dropout(drop_rate)
175
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
176
+
177
+ self.init_weights()
178
+
179
+ def init_weights(self) -> None:
180
+ if self.pos_embed is not None and not self.pos_embed.is_meta:
181
+ nn.init.trunc_normal_(self.pos_embed, std=.02)
182
+ if self.cls_token is not None and not self.cls_token.is_meta:
183
+ nn.init.normal_(self.cls_token, std=1e-6)
184
+ if self.reg_token is not None and not self.reg_token.is_meta:
185
+ nn.init.normal_(self.reg_token, std=1e-6)
186
+ self.apply(self._init_weights)
187
+
188
+ def _init_weights(self, m: nn.Module) -> None:
189
+ # this fn left here for compat with downstream users
190
+ if isinstance(m, nn.Linear):
191
+ if not m.weight.is_meta:
192
+ nn.init.trunc_normal_(m.weight, std=.02)
193
+ if m.bias is not None and not m.bias.is_meta:
194
+ nn.init.zeros_(m.bias)
195
+
196
+ @torch.jit.ignore
197
+ def no_weight_decay(self) -> Set:
198
+ return {'pos_embed', 'cls_token', 'dist_token'}
199
+
200
+ @torch.jit.ignore
201
+ def get_classifier(self) -> nn.Module:
202
+ return self.head
203
+
204
+ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
205
+ self.num_classes = num_classes
206
+ if global_pool is not None:
207
+ assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
208
+ if global_pool == 'map' and self.attn_pool is None:
209
+ assert False, "Cannot currently add attention pooling in reset_classifier()."
210
+ elif global_pool != 'map' and self.attn_pool is not None:
211
+ self.attn_pool = None # remove attention pooling
212
+ self.global_pool = global_pool
213
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
214
+
215
+ def _pos_embed(
216
+ self,
217
+ x: torch.Tensor,
218
+ grid_size: Optional[Union[int, Tuple[int, int, int]]] = None,
219
+ ):
220
+ if self.pos_embed is None and self.rope is None:
221
+ x = x.view(x.shape[0], -1, x.shape[-1])
222
+ if self.reg_token is not None:
223
+ x = torch.cat([self.reg_token.expand(x.shape[0], -1, -1), x], dim=1)
224
+ if self.cls_token is not None:
225
+ x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
226
+ return x, None
227
+
228
+ if self.dynamic_grid_size or self.rope is not None:
229
+ assert grid_size is not None, "grid_size must be provided when using dynamic_grid_size or RoPE."
230
+
231
+ pos_embed, rope = None, None
232
+ if self.pos_embed is not None:
233
+ if self.dynamic_grid_size:
234
+ H, W, D = to_3tuple(grid_size)
235
+ prev_grid_size = self.grid_size
236
+ pos_embed = resample_abs_pos_embed(
237
+ self.pos_embed,
238
+ new_size=(H, W, D),
239
+ old_size=prev_grid_size,
240
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
241
+ )
242
+ else:
243
+ pos_embed = self.pos_embed
244
+
245
+ if self.rope is not None:
246
+ B = x.shape[0]
247
+ H, W, D = to_3tuple(grid_size)
248
+ if self.requires_per_sample_rope:
249
+ rope = [self.rope(H=H, W=W, D=D) for _ in range(B)]
250
+ else:
251
+ rope = self.rope(H=H, W=W, D=D)
252
+
253
+ to_cat = []
254
+ if self.cls_token is not None:
255
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
256
+ if self.reg_token is not None:
257
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
258
+
259
+ if self.no_embed_class:
260
+ # deit-3, updated JAX (big vision)
261
+ # position embedding does not overlap with class token, add then concat
262
+ if pos_embed is not None:
263
+ x = x + pos_embed
264
+ if to_cat:
265
+ x = torch.cat(to_cat + [x], dim=1)
266
+ else:
267
+ # original timm, JAX, and deit vit impl
268
+ # pos_embed has entry for class token, concat then add
269
+ if to_cat:
270
+ x = torch.cat(to_cat + [x], dim=1)
271
+ if pos_embed is not None:
272
+ x = x + pos_embed
273
+
274
+ return self.pos_drop(x), rope
275
+
276
+ def forward_features(
277
+ self,
278
+ x: torch.Tensor,
279
+ grid_size: Optional[Union[int, Tuple[int, int, int]]] = None,
280
+ ) -> torch.Tensor:
281
+ assert x.ndim == 3, f"Expected input with 3 dimensions (B, N, C), got {x.ndim}."
282
+
283
+ x = self.patch_proj(x)
284
+ x, rope = self._pos_embed(x, grid_size)
285
+ x = self.patch_drop(x)
286
+ x = self.norm_pre(x)
287
+
288
+ for blk in self.blocks:
289
+ x = blk(x, rope=rope)
290
+ x = self.norm(x)
291
+ return x
292
+
293
+ def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
294
+ if self.attn_pool is not None:
295
+ x = self.attn_pool(x)
296
+ return x
297
+ pool_type = self.global_pool if pool_type is None else pool_type
298
+ x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
299
+ return x
300
+
301
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
302
+ x = self.pool(x)
303
+ x = self.fc_norm(x)
304
+ x = self.head_drop(x)
305
+ return x if pre_logits else self.head(x)
306
+
307
+ def forward(
308
+ self,
309
+ x: torch.Tensor,
310
+ grid_size: Optional[Union[int, Tuple[int, int, int]]] = None,
311
+ ) -> torch.Tensor:
312
+ x = self.forward_features(x, grid_size)
313
+ x = self.forward_head(x)
314
+ return x
315
+
316
+ @classmethod
317
+ def from_pretrained(
318
+ cls,
319
+ checkpoint_path_or_url: Union[str, os.PathLike],
320
+ verbose: bool = True,
321
+ **kwargs
322
+ ) -> 'FeatureVisionTransformer':
323
+ """Load pretrained model weights from a local path or a URL."""
324
+ model = cls(**kwargs)
325
+
326
+ def _is_url(path: str) -> bool:
327
+ try:
328
+ parsed = urlparse(str(path))
329
+ return parsed.scheme in ('http', 'https')
330
+ except Exception:
331
+ return False
332
+
333
+ def _is_hf_url(path: str) -> bool:
334
+ try:
335
+ parsed = urlparse(str(path))
336
+ return 'huggingface.co' in parsed.netloc
337
+ except Exception:
338
+ return False
339
+
340
+ if _is_hf_url(checkpoint_path_or_url):
341
+ if verbose:
342
+ print(f"Downloading pretrained weights from Hugging Face URL: {checkpoint_path_or_url}")
343
+ # Extract repo_id and filename from the URL
344
+ parsed = urlparse(checkpoint_path_or_url)
345
+ parts = parsed.path.strip('/').split('/')
346
+ repo_id = '/'.join(parts[:2]) # e.g., 'cclaess/SPECTRE'
347
+ filename = parts[-1] # e.g., 'spectre_backbone_vit_large_patch16_128.pt'
348
+
349
+ local_path = hf_hub_download(repo_id=repo_id, filename=filename)
350
+ state_dict = load_state_dict_from_file(local_path, map_location='cpu')
351
+ elif _is_url(checkpoint_path_or_url):
352
+ if verbose:
353
+ print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
354
+ state_dict = torch.hub.load_state_dict_from_url(
355
+ checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
356
+ else:
357
+ local_path = os.fspath(checkpoint_path_or_url)
358
+ if not os.path.exists(local_path):
359
+ raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
360
+ if verbose:
361
+ print(f"Loading checkpoint from local path: {local_path}")
362
+ state_dict = torch.load(local_path, map_location='cpu', weights_only=False)
363
+
364
+ msg = model.load_state_dict(state_dict, strict=False)
365
+ if verbose:
366
+ print(f"Loaded pretrained weights with msg: {msg}")
367
+ return model
368
+
369
+
370
+ def feat_vit_tiny(
371
+ patch_dim,
372
+ checkpoint_path_or_url: Optional[str] = None,
373
+ **kwargs,
374
+ ) -> FeatureVisionTransformer:
375
+ """Feature ViT-Tiny model.
376
+ """
377
+ kwargs = dict(
378
+ patch_dim=patch_dim,
379
+ embed_dim=192,
380
+ depth=2,
381
+ num_heads=2,
382
+ mlp_ratio=4,
383
+ qkv_bias=True,
384
+ norm_layer=nn.LayerNorm,
385
+ **kwargs,
386
+ )
387
+ if checkpoint_path_or_url is not None:
388
+ return FeatureVisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
389
+ return FeatureVisionTransformer(**kwargs)
390
+
391
+
392
+ def feat_vit_small(
393
+ patch_dim,
394
+ checkpoint_path_or_url: Optional[str] = None,
395
+ **kwargs,
396
+ ) -> FeatureVisionTransformer:
397
+ """Feature ViT-Small model.
398
+ """
399
+ kwargs = dict(
400
+ patch_dim=patch_dim,
401
+ embed_dim=384,
402
+ depth=2,
403
+ num_heads=4,
404
+ mlp_ratio=4,
405
+ qkv_bias=True,
406
+ norm_layer=nn.LayerNorm,
407
+ **kwargs,
408
+ )
409
+ if checkpoint_path_or_url is not None:
410
+ return FeatureVisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
411
+ return FeatureVisionTransformer(**kwargs)
412
+
413
+
414
+ def feat_vit_base(
415
+ patch_dim,
416
+ checkpoint_path_or_url: Optional[str] = None,
417
+ **kwargs,
418
+ ) -> FeatureVisionTransformer:
419
+ """Feature ViT-Base model.
420
+ """
421
+ kwargs = dict(
422
+ patch_dim=patch_dim,
423
+ embed_dim=768,
424
+ depth=2,
425
+ num_heads=8,
426
+ mlp_ratio=4,
427
+ qkv_bias=True,
428
+ norm_layer=nn.LayerNorm,
429
+ **kwargs,
430
+ )
431
+ if checkpoint_path_or_url is not None:
432
+ return FeatureVisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
433
+ return FeatureVisionTransformer(**kwargs)
434
+
435
+
436
+ def feat_vit_large(
437
+ patch_dim,
438
+ checkpoint_path_or_url: Optional[str] = None,
439
+ **kwargs,
440
+ ) -> FeatureVisionTransformer:
441
+ """Feature ViT-Large model.
442
+ """
443
+ kwargs = dict(
444
+ patch_dim=patch_dim,
445
+ embed_dim=1080,
446
+ depth=4,
447
+ num_heads=12,
448
+ mlp_ratio=4,
449
+ qkv_bias=True,
450
+ norm_layer=nn.LayerNorm,
451
+ **kwargs,
452
+ )
453
+ if checkpoint_path_or_url is not None:
454
+ return FeatureVisionTransformer.from_pretrained(checkpoint_path_or_url, **kwargs)
455
+ return FeatureVisionTransformer(**kwargs)
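Note on `from_pretrained` above: the loader picks one of three branches (Hugging Face URL, generic URL, local path) and, for Hugging Face URLs, derives `repo_id` and `filename` from the URL path. The sketch below reproduces only that dispatch and splitting logic with the standard library so it can be tried without downloading anything. `classify_source` and `split_hf_url` are hypothetical names, and the `resolve/main/` URL layout in the usage note is an assumption about how the checkpoint URLs look.

```python
from urllib.parse import urlparse


def classify_source(path_or_url):
    """Return 'hf', 'url', or 'local', checking in the same order as from_pretrained."""
    parsed = urlparse(str(path_or_url))
    if "huggingface.co" in parsed.netloc:
        return "hf"
    if parsed.scheme in ("http", "https"):
        return "url"
    return "local"


def split_hf_url(url):
    """Split a Hugging Face file URL into (repo_id, filename) the way the loader does."""
    parts = urlparse(str(url)).path.strip("/").split("/")
    repo_id = "/".join(parts[:2])  # first two path segments, e.g. 'cclaess/SPECTRE'
    filename = parts[-1]  # last path segment only; any subfolder is dropped
    return repo_id, filename


hf_url = "https://huggingface.co/cclaess/SPECTRE/resolve/main/spectre_backbone_vit_large_patch16_128.pt"
print(classify_source(hf_url), split_hf_url(hf_url))
print(classify_source("https://example.com/weights.pt"))
print(classify_source("checkpoints/model.pt"))
```

Because only the last path segment is kept, a checkpoint stored in a subfolder of the repository would be requested by its bare file name. If such a layout is ever used, the subfolder would need to be passed to the download call separately.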
spectre/utils/__init__.py ADDED
@@ -0,0 +1,117 @@
1
+ from ._utils import (
2
+ fix_random_seeds,
3
+ to_ntuple,
4
+ to_1tuple,
5
+ to_2tuple,
6
+ to_3tuple,
7
+ to_4tuple,
8
+ )
9
+ from .checkpointing import (
10
+ save_state,
11
+ load_state,
12
+ extract_model_from_checkpoint_dinov2,
13
+ extract_model_from_checkpoint_siglip,
14
+ )
15
+ from .collate import (
16
+ extended_collate_dino,
17
+ extended_collate_siglip,
18
+ collate_add_filenames,
19
+ )
20
+ from .config import setup
21
+ from .dataloader import get_dataloader
22
+ from .distributed import (
23
+ is_enabled,
24
+ get_global_size,
25
+ get_global_rank,
26
+ get_local_size,
27
+ get_local_rank,
28
+ init_distributed,
29
+ )
30
+ from .lora import add_lora_adapters
31
+ from .masking import random_block_mask
32
+ from .modeling import (
33
+ deactivate_requires_grad_and_to_eval,
34
+ activate_requires_grad_and_to_train,
35
+ update_momentum,
36
+ update_drop_path_rate,
37
+ repeat_token,
38
+ expand_index_like,
39
+ get_at_index,
40
+ set_at_index,
41
+ mask_at_index,
42
+ mask_bool,
43
+ patchify,
44
+ random_token_mask,
45
+ resample_abs_pos_embed,
46
+ resample_abs_pos_embed_nhwdc,
47
+ resample_patch_embed,
48
+ feature_take_indices,
49
+ global_pool_nlc,
50
+ cat_keep_shapes,
51
+ uncat_with_shapes,
52
+ last_token_pool,
53
+ Format,
54
+ nchwd_to,
55
+ nhwdc_to,
56
+ )
57
+ from .param_groups import get_param_groups_with_decay
58
+ from .scheduler import (
59
+ linear_warmup_schedule,
60
+ cosine_schedule,
61
+ cosine_warmup_schedule,
62
+ CosineWarmupScheduler,
63
+ )
64
+
65
+ __all__ = [
66
+ "fix_random_seeds",
67
+ "to_ntuple",
68
+ "to_1tuple",
69
+ "to_2tuple",
70
+ "to_3tuple",
71
+ "to_4tuple",
72
+ "save_state",
73
+ "load_state",
74
+ "extract_model_from_checkpoint_dinov2",
75
+ "extract_model_from_checkpoint_siglip",
76
+ "extended_collate_dino",
77
+ "extended_collate_siglip",
78
+ "collate_add_filenames",
79
+ "setup",
80
+ "get_dataloader",
81
+ "is_enabled",
82
+ "get_global_size",
83
+ "get_global_rank",
84
+ "get_local_size",
85
+ "get_local_rank",
86
+ "init_distributed",
87
+ "add_lora_adapters",
88
+ "random_block_mask",
89
+ "deactivate_requires_grad_and_to_eval",
90
+ "activate_requires_grad_and_to_train",
91
+ "update_momentum",
92
+ "update_drop_path_rate",
93
+ "repeat_token",
94
+ "expand_index_like",
95
+ "get_at_index",
96
+ "set_at_index",
97
+ "mask_at_index",
98
+ "mask_bool",
99
+ "patchify",
100
+ "random_token_mask",
101
+ "resample_abs_pos_embed",
102
+ "resample_abs_pos_embed_nhwdc",
103
+ "resample_patch_embed",
104
+ "feature_take_indices",
105
+ "global_pool_nlc",
106
+ "cat_keep_shapes",
107
+ "uncat_with_shapes",
108
+ "last_token_pool",
109
+ "Format",
110
+ "nchwd_to",
111
+ "nhwdc_to",
112
+ "get_param_groups_with_decay",
113
+ "linear_warmup_schedule",
114
+ "cosine_schedule",
115
+ "cosine_warmup_schedule",
116
+ "CosineWarmupScheduler",
117
+ ]
spectre/utils/_utils.py ADDED
@@ -0,0 +1,49 @@
1
+ import random
2
+ from typing import Iterable
3
+ from itertools import repeat
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ MONAI_IMPORT_ERROR = None
9
+ try:
10
+ import monai
11
+ except ImportError as e:
12
+ monai = None # type: ignore
13
+ MONAI_IMPORT_ERROR = e
14
+
15
+ def fix_random_seeds(seed: int = 31):
16
+ """
17
+ Fix random seeds.
18
+ """
19
+ if MONAI_IMPORT_ERROR is not None:
20
+ raise ImportError(
21
+ "MONAI is required to use fix_random_seeds but not installed. "
22
+ "Please install MONAI to use this function."
23
+ ) from MONAI_IMPORT_ERROR
24
+
25
+ torch.manual_seed(seed)
26
+ torch.cuda.manual_seed_all(seed)
27
+ np.random.seed(seed)
28
+ random.seed(seed)
29
+ torch.backends.cudnn.deterministic = True
30
+ torch.backends.cudnn.benchmark = False
31
+ monai.utils.set_determinism(seed=seed)
32
+
33
+
34
+ def _ntuple(n: int):
35
+ """
36
+ Helper function to create n-tuple.
37
+ """
38
+ def parse(x):
39
+ if isinstance(x, Iterable) and not isinstance(x, str):
40
+ return tuple(x)
41
+ return tuple(repeat(x, n))
42
+ return parse
43
+
44
+
45
+ to_1tuple = _ntuple(1)
46
+ to_2tuple = _ntuple(2)
47
+ to_3tuple = _ntuple(3)
48
+ to_4tuple = _ntuple(4)
49
+ to_ntuple = _ntuple
spectre/utils/checkpointing.py ADDED
@@ -0,0 +1,238 @@
1
+ import os
2
+ import random
3
+ import warnings
4
+ from typing import Optional, Any
5
+
6
+ import torch
7
+ import numpy as np
8
+ import torch.distributed as dist
9
+
10
+
11
+ def _get_local_rng_state() -> dict:
12
+ """Return a picklable dict with local RNG states (cpu & cuda, numpy, random)."""
13
+ state = {
14
+ "torch": torch.get_rng_state().cpu(),
15
+ "numpy": np.random.get_state(),
16
+ "random": random.getstate(),
17
+ }
18
+
19
+ if torch.cuda.is_available():
20
+ # make sure CUDA states are stored on CPU so they are picklable
21
+ cuda_states = [s.cpu() for s in torch.cuda.get_rng_state_all()]
22
+ state["cuda"] = cuda_states
23
+ else:
24
+ state["cuda"] = None
25
+
26
+ return state
27
+
28
+
29
+ def _set_local_rng_state(state: dict) -> None:
30
+ """Set local RNG states from the dict produced by _get_local_rng_state()."""
31
+ if state is None:
32
+ return
33
+
34
+ if "torch" in state and state["torch"] is not None:
35
+ torch.set_rng_state(state["torch"])
36
+ if "cuda" in state and state["cuda"] is not None and torch.cuda.is_available():
37
+ try:
38
+ # torch.cuda.set_rng_state_all expects CPU ByteTensors, so keep the states on CPU
39
+ cuda_states = [s.cpu() for s in state["cuda"]]
40
+ torch.cuda.set_rng_state_all(cuda_states)
41
+ except Exception:
42
+ # fallback: try setting per-device RNG if set_rng_state_all fails
43
+ for i, s in enumerate(state["cuda"]):
44
+ try:
45
+ torch.cuda.set_rng_state(s.cpu(), device=i)
46
+ except Exception:
47
+ # ignore if device mismatch
48
+ pass
49
+
50
+ if "numpy" in state and state["numpy"] is not None:
51
+ np.random.set_state(state["numpy"])
52
+ if "random" in state and state["random"] is not None:
53
+ random.setstate(state["random"])
54
+
55
+
56
+ def save_state(ckpt_path: str, epoch: Optional[int] = None, **named_objects: Any) -> None:
57
+ """
58
+ Save a checkpoint that includes:
59
+ - epoch (optional)
60
+ - state_dicts for provided named_objects
61
+ - rng_states: list of per-rank RNG dictionaries (one entry per world rank)
62
+
63
+ If torch.distributed is initialized the RNG states from all ranks are gathered and
64
+ stored in checkpoint["rng_states"] (list indexed by rank). Only rank 0 writes the file.
65
+ In single-process mode the checkpoint contains a single-item rng_states list.
66
+ """
67
+ os.makedirs(os.path.dirname(os.path.abspath(ckpt_path)), exist_ok=True)
68
+
69
+ # prepare local RNG state
70
+ local_rng = _get_local_rng_state()
71
+
72
+ # distributed path: gather RNG states from all ranks
73
+ if dist.is_available() and dist.is_initialized():
74
+ rank = dist.get_rank()
75
+ world_size = dist.get_world_size()
76
+ all_states = [None] * world_size
77
+ # gather python objects (picklable)
78
+ dist.all_gather_object(all_states, local_rng)
79
+
80
+ # only rank 0 writes the checkpoint file
81
+ if rank == 0:
82
+ checkpoint = {}
83
+ if epoch is not None:
84
+ checkpoint["epoch"] = epoch
85
+ checkpoint["rng_states"] = all_states
86
+
87
+ # save provided objects' state_dicts (rank 0's local state_dicts)
88
+ for name, obj in named_objects.items():
89
+ checkpoint[name] = obj.state_dict()
90
+
91
+ torch.save(checkpoint, ckpt_path)
92
+
93
+ # ensure everyone waits until rank 0 finished writing
94
+ dist.barrier()
95
+
96
+ else:
97
+ # single-process fallback
98
+ checkpoint = {}
99
+ if epoch is not None:
100
+ checkpoint["epoch"] = epoch
101
+ checkpoint["rng_states"] = [local_rng]
102
+ for name, obj in named_objects.items():
103
+ checkpoint[name] = obj.state_dict()
104
+ torch.save(checkpoint, ckpt_path)
105
+
106
+
107
+ def load_state(ckpt_path: str, **named_objects: Any) -> int:
108
+ """
109
+ Load checkpoint saved by save_state.
110
+
111
+ - Each process loads the same file and restores its own RNG state (checkpoint['rng_states'][rank]).
112
+ - Named objects that exist in the checkpoint will have their state_dict loaded.
113
+ - Returns epoch (int) if present, otherwise 0.
114
+ """
115
+ if not os.path.isfile(ckpt_path):
116
+ warnings.warn(f"Checkpoint file not found: {ckpt_path}")
117
+ return 0
118
+
119
+ # load on all ranks (shared FS assumed)
120
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
121
+ epoch = checkpoint.get("epoch", 0)
122
+
123
+ # load state_dicts into provided objects
124
+ for name, obj in named_objects.items():
125
+ if name in checkpoint:
126
+ try:
127
+ obj.load_state_dict(checkpoint[name])
128
+ except Exception as e:
129
+ warnings.warn(f"Failed to load state_dict for '{name}': {e}")
130
+ else:
131
+ warnings.warn(f"No state_dict found for '{name}' in checkpoint.")
132
+
133
+ # restore this rank's RNG state
134
+ rng_states = checkpoint.get("rng_states", None)
135
+ if rng_states is not None:
136
+ if dist.is_available() and dist.is_initialized():
137
+ rank = dist.get_rank()
138
+ if rank < len(rng_states):
139
+ my_state = rng_states[rank]
140
+ else:
141
+ my_state = None
142
+ else:
143
+ # single-process file: first element
144
+ my_state = rng_states[0] if len(rng_states) > 0 else None
145
+
146
+ try:
147
+ _set_local_rng_state(my_state)
148
+ except Exception as e:
149
+ warnings.warn(f"Failed to restore RNG state: {e}")
150
+
151
+ else:
152
+ warnings.warn("No 'rng_states' found in checkpoint; RNGs not restored.")
153
+
154
+ return epoch
155
+
156
+
157
+ def extract_model_from_checkpoint_dinov2(checkpoint_path: str):
158
+ # Load the checkpoint
159
+ checkpoint = torch.load(
160
+ checkpoint_path,
161
+ weights_only=False,
162
+ map_location="cpu"
163
+ )
164
+
165
+ # Get model state dict
166
+ model_state = checkpoint.get("model", checkpoint)
167
+
168
+ # Create output folder
169
+ output_dir = str(checkpoint_path).replace(".pt", "")
170
+ os.makedirs(output_dir, exist_ok=True)
171
+
172
+ # Quick check: compare the parameters of head_teacher_ibot vs head_teacher_dino
173
+ teacher_dino_keys = [k for k in model_state.keys() if k.startswith("head_teacher_dino.")]
174
+ teacher_ibot_keys = [k for k in model_state.keys() if k.startswith("head_teacher_ibot.")]
175
+
176
+ ibot_separate = True
177
+ if teacher_dino_keys and teacher_ibot_keys:
178
+ if all(torch.equal(model_state[dino_key], model_state[ibot_key]) \
179
+ for dino_key, ibot_key in zip(teacher_dino_keys, teacher_ibot_keys)):
180
+ ibot_separate = False # Same weights → no separate ibot head
181
+
182
+ # Define the components to save
183
+ components = {
184
+ "backbone_teacher.pt": "backbone_teacher.vit",
185
+ "backbone_student.pt": "backbone_student.vit",
186
+ "head_student_dino.pt": "head_student_dino",
187
+ "head_teacher_dino.pt": "head_teacher_dino"
188
+ }
189
+
190
+ # Add ibot heads only if separate
191
+ if ibot_separate:
192
+ components["head_student_ibot.pt"] = "head_student_ibot"
193
+ components["head_teacher_ibot.pt"] = "head_teacher_ibot"
194
+
195
+ # Extract and save each component
196
+ for filename, key in components.items():
197
+ sub_state_dict = {k[len(key) + 1:]: v for k, v in model_state.items() if k.startswith(f"{key}.")}
198
+ if not sub_state_dict:
199
+ print(f"[WARNING] No parameters found for {key}, skipping...")
200
+ continue
201
+ torch.save(sub_state_dict, os.path.join(output_dir, filename))
202
+
203
+ print(f"Components extracted to: {output_dir}")
204
+
205
+
206
+ def extract_model_from_checkpoint_siglip(checkpoint_path: str):
207
+ # Load the checkpoint
208
+ checkpoint = torch.load(
209
+ checkpoint_path,
210
+ weights_only=False,
211
+ map_location="cpu",
212
+ )
213
+
214
+ # Get model state dict
215
+ model_state = checkpoint.get("model", checkpoint)
216
+
217
+ # Create output folder
218
+ output_dir = str(checkpoint_path).replace(".pt", "")
219
+ os.makedirs(output_dir, exist_ok=True)
220
+
221
+ # Define the components to save
222
+ components = {
223
+ "backbone_image.pt": "backbone_image",
224
+ "backbone_text.pt": "backbone_text",
225
+ "feature_comb_image.pt": "feature_comb_image",
226
+ "projection_image.pt": "projection_image",
227
+ "projection_text.pt": "projection_text"
228
+ }
229
+
230
+ # Extract and save each component
231
+ for filename, key in components.items():
232
+ sub_state_dict = {k[len(key) + 1:]: v for k, v in model_state.items() if k.startswith(f"{key}.")}
233
+ if not sub_state_dict:
234
+ print(f"[WARNING] No parameters found for {key}, skipping...")
235
+ continue
236
+ torch.save(sub_state_dict, os.path.join(output_dir, filename))
237
+
238
+ print(f"Components extracted to: {output_dir}")
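The per-rank RNG handling in `save_state` / `load_state` above can be illustrated with a minimal sketch. It assumes only Python's `random` module stands in for the full state dict (torch, CUDA, and NumPy states are omitted), and it simulates two ranks in one process instead of using `torch.distributed`. The helper names are hypothetical.

```python
import random


def get_local_rng_state():
    """Capture only Python's RNG state (stand-in for the full per-rank dict)."""
    return {"random": random.getstate()}


def set_local_rng_state(state):
    """Restore the RNG state if one was saved for this rank."""
    if state is not None and state.get("random") is not None:
        random.setstate(state["random"])


def pick_state_for_rank(rng_states, rank):
    """Same lookup as load_state: None when the rank has no saved entry."""
    return rng_states[rank] if rank < len(rng_states) else None


# "Save": simulate two ranks, each with its own RNG stream.
world_size = 2
rng_states, expected = [], []
for rank in range(world_size):
    random.seed(1234 + rank)
    rng_states.append(get_local_rng_state())
    expected.append([random.random() for _ in range(3)])

# "Resume": each rank restores its own entry and continues the same stream.
resumed = []
for rank in range(world_size):
    random.seed(0)  # scramble the state to mimic a fresh process
    set_local_rng_state(pick_state_for_rank(rng_states, rank))
    resumed.append([random.random() for _ in range(3)])
```

Resuming with a larger world size than the one used for saving leaves the extra ranks without a stored state, which is why the lookup returns `None` in that case.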
spectre/utils/collate.py ADDED
@@ -0,0 +1,120 @@
1
+ from typing import List, Callable, Optional
2
+
3
+ import torch
4
+
5
+ MONAI_IMPORT_ERROR = None
6
+ try:
7
+ from monai.data import list_data_collate
8
+ except ImportError as e:
9
+ list_data_collate = lambda x: x # type: ignore
10
+ MONAI_IMPORT_ERROR = e
11
+
12
+
13
+ def extended_collate_dino(samples_list: List) -> dict:
14
+ """
15
+ Applies MONAI's list_data_collate first and then concatenates the global and local views.
16
+
17
+ Args:
18
+ samples_list: List of samples containing 'image_global_views' and 'image_local_views'.
24
+
25
+ Returns:
26
+ A dictionary with the concatenated 'global_views' and 'local_views' tensors.
27
+ """
28
+ if MONAI_IMPORT_ERROR is not None:
29
+ raise ImportError(
30
+ "MONAI is required to use extended_collate_dino but not installed. "
31
+ "Please install MONAI to use this collate function."
32
+ ) from MONAI_IMPORT_ERROR
33
+
34
+ # Apply MONAI's list_data_collate
35
+ collated_data = list_data_collate(samples_list)
36
+
37
+ # Extract crops
38
+ global_views = torch.cat(collated_data["image_global_views"], dim=0)
39
+ local_views = torch.cat(collated_data["image_local_views"], dim=0)
40
+
41
+ return {
42
+ "global_views": global_views,
43
+ "local_views": local_views,
44
+ }
45
+
46
+
47
+ def extended_collate_siglip(
48
+ samples_list: List,
49
+ tokenizer: Optional[Callable] = None,
50
+ tokenizer_padding: bool = True,
51
+ tokenizer_truncation: bool = True,
52
+ tokenizer_max_length: Optional[int] = 1024,
53
+ return_filenames: bool = False
54
+ ) -> dict:
55
+ """
56
+ Applies MONAI's list_data_collate and then extends it with tokenization logic.
57
+
58
+ Args:
59
+ samples_list: List of samples containing 'image' and 'report'.
60
+ tokenizer: Tokenizer function to apply on the reports.
61
+
62
+ Returns:
63
+ A dictionary with collated images and tokenized text.
64
+ """
65
+ if MONAI_IMPORT_ERROR is not None:
66
+ raise ImportError(
67
+ "MONAI is required to use extended_collate_siglip but not installed. "
68
+ "Please install MONAI to use this collate function."
69
+ ) from MONAI_IMPORT_ERROR
70
+
71
+ collated_data = list_data_collate(samples_list)
72
+
73
+ if return_filenames:
74
+ if "image" in collated_data.keys():
75
+ if (
76
+ hasattr(samples_list[0]["image"].data, "meta")
77
+ and "filename_or_obj" in samples_list[0]["image"].data.meta
78
+ ):
79
+ collated_data["filename"] = [s["image"].data.meta["filename_or_obj"] for s in samples_list]
80
+
81
+ if tokenizer is not None and "report" in collated_data.keys():
82
+ tokenizer_output = tokenizer.batch_encode_plus(
83
+ collated_data["report"],
84
+ add_special_tokens=True,
85
+ padding=tokenizer_padding,
86
+ truncation=tokenizer_truncation,
87
+ max_length=tokenizer_max_length,
88
+ )
89
+
90
+ collated_data["input_ids"] = torch.tensor(tokenizer_output["input_ids"])
91
+ collated_data["attention_mask"] = torch.tensor(tokenizer_output["attention_mask"])
92
+
93
+ return collated_data
94
+
95
+
96
+ def collate_add_filenames(samples_list: List) -> dict:
97
+ """
98
+ Applies MONAI's list_data_collate and adds filenames to the collated output.
99
+
100
+ Args:
101
+ samples_list: List of samples containing 'image' with metadata.
102
+ Returns:
103
+ A dictionary with collated images and filenames.
104
+ """
105
+ if MONAI_IMPORT_ERROR is not None:
106
+ raise ImportError(
107
+ "MONAI is required to use collate_add_filenames but not installed. "
108
+ "Please install MONAI to use this collate function."
109
+ ) from MONAI_IMPORT_ERROR
110
+
111
+ collated_data = list_data_collate(samples_list)
112
+
113
+ if "image" in collated_data.keys():
114
+ if (
115
+ hasattr(samples_list[0]["image"].data, "meta")
116
+ and "filename_or_obj" in samples_list[0]["image"].data.meta
117
+ ):
118
+ collated_data["filename"] = [s["image"].data.meta["filename_or_obj"] for s in samples_list]
119
+
120
+ return collated_data
spectre/utils/config.py ADDED
@@ -0,0 +1,91 @@
1
+ import os
2
+ import math
3
+
4
+ from spectre.utils import _utils, distributed
5
+
6
+ OMEGACONF_IMPORT_ERROR = None
7
+ try:
8
+ from omegaconf import OmegaConf
9
+ except ImportError as e:
10
+ OmegaConf = None # type: ignore
11
+ OMEGACONF_IMPORT_ERROR = e
12
+
13
+
14
+ def apply_scaling_rules_to_cfg(cfg):
15
+ """
16
+ Apply learning rate scaling rules to the configuration object.
17
+ """
18
+ base_lr = cfg.optim.base_lr
19
+ cfg.optim.lr = base_lr
20
+
21
+ # Apply scaling rules
22
+ if cfg.optim.scaling_rule == "constant":
23
+ return cfg
24
+
25
+ try:
26
+ scaling_type, ref_batch_size = cfg.optim.scaling_rule.split("_wrt_")
27
+ ref_batch_size = float(ref_batch_size)
28
+ except ValueError:
29
+ raise NotImplementedError(f"Unknown scaling rule: {cfg.optim.scaling_rule}")
30
+
31
+ scale_factor = cfg.train.batch_size_per_gpu * distributed.get_global_size()
32
+ scale_factor /= ref_batch_size
33
+ scale_factor *= cfg.train.grad_accum_steps
34
+
35
+ if scaling_type == "sqrt":
36
+ cfg.optim.lr *= math.sqrt(scale_factor)
37
+ elif scaling_type == "linear":
38
+ cfg.optim.lr *= scale_factor
39
+ else:
40
+ raise NotImplementedError(f"Unsupported scaling type: {scaling_type}")
41
+
42
+ return cfg
43
+
44
+
45
+ def write_config(cfg, output_dir, name="config.yaml"):
46
+ if OMEGACONF_IMPORT_ERROR is not None:
47
+ raise ImportError(
48
+ "OmegaConf is required to use write_config but not installed. "
49
+ "Please install OmegaConf to use this function."
50
+ ) from OMEGACONF_IMPORT_ERROR
51
+
52
+ saved_cfg_path = os.path.join(output_dir, name)
53
+ with open(saved_cfg_path, "w") as f:
54
+ OmegaConf.save(config=cfg, f=f)
55
+ return saved_cfg_path
56
+
57
+
58
+ def get_cfg_from_args(args, default_config):
59
+ if OMEGACONF_IMPORT_ERROR is not None:
60
+ raise ImportError(
61
+ "OmegaConf is required to use get_cfg_from_args but not installed. "
62
+ "Please install OmegaConf to use this function."
63
+ ) from OMEGACONF_IMPORT_ERROR
64
+
65
+ args.output_dir = os.path.abspath(args.output_dir)
66
+ args.opts = [] if args.opts is None else args.opts
67
+ args.opts += [f"train.output_dir={args.output_dir}"]
68
+ default_cfg = OmegaConf.create(default_config)
69
+ cfg = OmegaConf.load(args.config_file)
70
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
71
+ return cfg
72
+
73
+
74
+ def random_seed(args):
75
+ seed = getattr(args, "seed", 0)
76
+ rank = distributed.get_global_rank()
77
+
78
+ _utils.fix_random_seeds(seed + rank)
79
+
80
+
81
+ def setup(args, default_config):
82
+ """
83
+ Create configs and perform basic setups.
84
+ """
85
+ cfg = get_cfg_from_args(args, default_config)
86
+ os.makedirs(args.output_dir, exist_ok=True)
87
+ accelerator = distributed.init_distributed(cfg)
88
+ random_seed(args) # after init_distributed so the global rank is known when seeding
89
+ apply_scaling_rules_to_cfg(cfg)
90
+ write_config(cfg, args.output_dir)
91
+ return cfg, accelerator
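As a worked example of the scaling rules in `apply_scaling_rules_to_cfg`, the sketch below reimplements the arithmetic as a standalone function. The function name `scaled_lr` and its flat argument list are assumptions for illustration; the original operates on a config object.

```python
import math


def scaled_lr(base_lr, scaling_rule, batch_size_per_gpu, world_size, grad_accum_steps):
    """Scale base_lr by the effective batch size relative to a reference batch size."""
    if scaling_rule == "constant":
        return base_lr
    try:
        scaling_type, ref_batch_size = scaling_rule.split("_wrt_")
        ref_batch_size = float(ref_batch_size)
    except ValueError:
        raise NotImplementedError(f"Unknown scaling rule: {scaling_rule}")

    # effective batch size = per-GPU batch * number of processes * accumulation steps
    scale = batch_size_per_gpu * world_size * grad_accum_steps / ref_batch_size
    if scaling_type == "sqrt":
        return base_lr * math.sqrt(scale)
    if scaling_type == "linear":
        return base_lr * scale
    raise NotImplementedError(f"Unsupported scaling type: {scaling_type}")


# 64 per GPU * 8 GPUs * 8 accumulation steps = 4096, i.e. 4x the reference of 1024
lr_sqrt = scaled_lr(1e-3, "sqrt_wrt_1024", 64, 8, 8)      # base_lr * sqrt(4)
lr_linear = scaled_lr(1e-3, "linear_wrt_1024", 64, 8, 8)  # base_lr * 4
```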
spectre/utils/dataloader.py ADDED
@@ -0,0 +1,126 @@
1
+ from __future__ import annotations
2
+ import os
3
+ from typing import Union, Callable, Optional, List
4
+
5
+ import torch
6
+ from torch.utils.data import ConcatDataset
7
+
8
+ MONAI_IMPORT_ERROR = None
9
+ try:
10
+ import monai.data as data
11
+ except ImportError as e:
12
+ data = None # type: ignore
13
+ MONAI_IMPORT_ERROR = e
14
+
15
+
16
+
17
+ def get_dataloader(
18
+ datasets: Union[str, List[str]],
19
+ data_dir: str,
20
+ include_reports: bool = False,
21
+ include_labels: bool = False,
22
+ cache_dataset: bool = False,
23
+ cache_dir: Optional[str] = None,
24
+ use_gds: bool = False,
25
+ transform: Optional[Callable] = None,
26
+ fraction: float = 1.0,
27
+ batch_size: int = 64,
28
+ num_workers: int = 4,
29
+ pin_memory: bool = True,
30
+ shuffle: bool = True,
31
+ collate_fn: Optional[Callable] = None,
32
+ drop_last: bool = True,
33
+ persistent_workers: bool = True,
34
+ use_thread: bool = False,
35
+ ) -> "DataLoader":
36
+ """
37
+ Get dataloader for training.
38
+ """
39
+ if MONAI_IMPORT_ERROR is not None:
40
+ raise ImportError(
41
+ "MONAI is required to use get_dataloader but not installed. "
42
+ "Please install MONAI to use this function."
43
+ ) from MONAI_IMPORT_ERROR
44
+
45
+ if isinstance(datasets, str):
46
+ datasets = [datasets]
47
+
48
+ # Validate constraints
49
+ if include_reports:
50
+ assert set(datasets).issubset({"ct_rate", "merlin", "inspect"}), \
51
+ "When include_reports=True, only 'ct_rate', 'merlin', and 'inspect' are allowed."
52
+ if include_labels:
53
+ assert set(datasets).issubset({"abdomen_atlas", "abdomenct_1k"}), \
54
+ "When include_labels=True, only 'abdomen_atlas' and 'abdomenct_1k' are allowed."
55
+ if use_gds:
56
+ assert cache_dataset, "GDS requires cache_dataset=True."
57
+ assert torch.cuda.is_available(), "GDS requires CUDA to be available."
58
+
59
+ # Dataset configurations
60
+ DATASET_CONFIGS = {
61
+ "ct_rate": {"folder": "CT-RATE", "base_name": "CTRate",
62
+ "extra": {"include_reports": include_reports}},
63
+ "inspect": {"folder": "INSPECT", "base_name": "Inspect",
64
+ "extra": {"include_reports": include_reports}},
65
+ "merlin": {"folder": "MERLIN", "base_name": "Merlin",
66
+ "extra": {"include_reports": include_reports}},
67
+ "nlst": {"folder": "NLST", "base_name": "Nlst"},
68
+ "amos": {"folder": "Amos", "base_name": "Amos"},
69
+ "abdomen_atlas": {"folder": "AbdomenAtlas1.0Mini", "base_name": "AbdomenAtlas",
70
+ "extra": {"include_labels": include_labels}},
71
+ "panorama": {"folder": "PANORAMA", "base_name": "Panorama"},
72
+ "abdomenct_1k": {"folder": "AbdomenCT-1K", "base_name": "AbdomenCT1K",
73
+ "extra": {"include_labels": include_labels}},
74
+ }
75
+
76
+ datasets_list = []
77
+ for ds in datasets:
78
+ if ds.lower() not in DATASET_CONFIGS:
79
+ raise NotImplementedError(f"Dataset {ds} not implemented.")
80
+
81
+ cfg = DATASET_CONFIGS[ds.lower()]
82
+ folder = cfg["folder"]
83
+ extra_args = cfg.get("extra", {})
84
+
85
+ kwargs = {
86
+ "data_dir": os.path.join(data_dir, folder),
87
+ "transform": transform,
88
+ "fraction": fraction,
89
+ **extra_args,
90
+ }
91
+
92
+ base_name = cfg["base_name"]
93
+ class_suffix = "Dataset"
94
+ if cache_dataset:
95
+ class_suffix = "GDSDataset" if use_gds else "PersistentDataset"
96
+
97
+ class_name = f"{base_name}{class_suffix}"
98
+ DatasetClass = getattr(__import__("spectre.data", fromlist=[class_name]), class_name)
99
+
100
+ if cache_dataset:
101
+ kwargs["cache_dir"] = os.path.join(cache_dir, folder)
102
+ if use_gds:
103
+ kwargs["device"] = torch.cuda.current_device()
104
+
105
+ datasets_list.append(DatasetClass(**kwargs))
106
+
107
+ dataset = datasets_list[0] if len(datasets_list) == 1 else ConcatDataset(datasets_list)
108
+
109
+ loader_cls = getattr(data, "ThreadDataLoader" if use_thread else "DataLoader")
110
+ loader_kwargs = {
111
+ "dataset": dataset,
112
+ "batch_size": batch_size,
113
+ "num_workers": num_workers,
114
+ "shuffle": shuffle,
115
+ "drop_last": drop_last,
116
+ }
117
+
118
+ if not use_thread:
119
+ loader_kwargs.update({
120
+ "pin_memory": pin_memory,
121
+ "persistent_workers": persistent_workers and num_workers > 0
122
+ })
123
+ if collate_fn is not None:
124
+ loader_kwargs["collate_fn"] = collate_fn
125
+
126
+ return loader_cls(**loader_kwargs)
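The dataset class lookup in `get_dataloader` builds a class name from a base name and a suffix that depends on the caching options. A minimal sketch of just that naming rule is shown below. The function name is hypothetical, and it raises `ValueError` where the original uses an `assert`.

```python
def resolve_dataset_class_name(base_name, cache_dataset=False, use_gds=False):
    """Return the class name that get_dataloader would look up in spectre.data."""
    if use_gds and not cache_dataset:
        raise ValueError("GDS requires cache_dataset=True.")
    suffix = "Dataset"
    if cache_dataset:
        suffix = "GDSDataset" if use_gds else "PersistentDataset"
    return f"{base_name}{suffix}"


plain = resolve_dataset_class_name("CTRate")
cached = resolve_dataset_class_name("CTRate", cache_dataset=True)
gds = resolve_dataset_class_name("Merlin", cache_dataset=True, use_gds=True)
```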
spectre/utils/distributed.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+
3
+ import torch.distributed as dist
4
+
5
+ ACCELERATE_IMPORT_ERROR = None
6
+ try:
7
+ from accelerate import Accelerator, DataLoaderConfiguration
8
+ except ImportError as e:
9
+ Accelerator = None # type: ignore
10
+ DataLoaderConfiguration = None # type: ignore
11
+ ACCELERATE_IMPORT_ERROR = e
12
+
13
+
14
+ def is_enabled() -> bool:
15
+ """
16
+ Returns:
17
+ True if distributed training is enabled
18
+ """
19
+ return dist.is_available() and dist.is_initialized()
20
+
21
+
22
+ def get_global_size() -> int:
23
+ """
24
+ Returns:
25
+ Number of processes in the distributed group
26
+ """
27
+ if not is_enabled():
28
+ return 1
29
+ return dist.get_world_size()
30
+
31
+
32
+ def get_global_rank() -> int:
33
+ """
34
+ Returns:
35
+ The rank of the current process in the distributed group
36
+ """
37
+ if not is_enabled():
38
+ return 0
39
+ return dist.get_rank()
40
+
41
+
42
+ def get_local_size() -> int:
43
+ """
44
+ Returns:
45
+ Number of processes on the current machine
46
+ """
47
+ if not is_enabled():
48
+ return 1
49
+ return int(os.environ.get("LOCAL_SIZE", 1))
50
+
51
+
52
+ def get_local_rank() -> int:
53
+ """
54
+ Returns:
55
+ The rank of the current process on the current machine
56
+ """
57
+ if not is_enabled():
58
+ return 0
59
+ return int(os.environ.get("LOCAL_RANK", 0))
60
+
61
+
62
+ def init_distributed(cfg):
63
+ """
64
+ Initialize distributed training.
65
+ """
66
+ if ACCELERATE_IMPORT_ERROR is not None:
67
+ raise ImportError(
68
+ "Accelerate is required to use init_distributed but not installed. "
69
+ "Please install Accelerate to use this function."
70
+ ) from ACCELERATE_IMPORT_ERROR
71
+
72
+ # Initialize accelerator
73
+ dataloader_config = DataLoaderConfiguration(
74
+ non_blocking=cfg.train.pin_memory,
75
+ )
76
+ accelerator = Accelerator(
77
+ gradient_accumulation_steps=cfg.train.grad_accum_steps,
78
+ log_with="wandb" if cfg.train.log_wandb else None,
79
+ dataloader_config=dataloader_config,
80
+ )
81
+
82
+ # Initialize wandb
83
+ if cfg.train.log_wandb:
84
+ accelerator.init_trackers(
85
+ project_name="spectre",
86
+ config={k: v for d in cfg.values() for k, v in d.items()},
87
+ init_kwargs={
88
+ "dir": os.path.join(cfg.train.output_dir, "logs"),
89
+ },
90
+ )
91
+
92
+ return accelerator
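The tracker config in `init_distributed` flattens the nested config by one level. The sketch below, using a plain dict with made-up values in place of the OmegaConf object, shows what that comprehension does when two sections share a key, and a prefixed alternative. Both helper names are hypothetical.

```python
def flatten_one_level(cfg):
    """Same comprehension as in init_distributed: later sections override earlier ones."""
    return {k: v for section in cfg.values() for k, v in section.items()}


def flatten_with_prefix(cfg):
    """Alternative that keeps the section name, so colliding keys are preserved."""
    return {f"{name}.{k}": v for name, section in cfg.items() for k, v in section.items()}


# Illustrative values only
cfg = {"train": {"batch_size": 8, "lr": 1.0}, "optim": {"lr": 0.001}}
flat = flatten_one_level(cfg)
prefixed = flatten_with_prefix(cfg)
```

If the real config never reuses a key across sections, the one-level flattening loses nothing. The prefixed form is only worth considering when collisions are possible.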
spectre/utils/lora.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch.nn as nn
2
+ import loralib as lora
3
+
4
+
5
+ def add_lora_adapters(
6
+ root_module: nn.Module,
7
+ r: int = 8,
8
+ lora_alpha: int = 32,
9
+ lora_dropout: float = 0.05,
10
+ target_keywords: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
11
+ ) -> None:
12
+ """
13
+ Recursively traverses the model and replaces every `nn.Linear`
14
+ whose name contains one of `target_keywords` with a LoRA-augmented
15
+ linear layer from loralib.
16
+ """
17
+
18
+ for name, child in list(root_module.named_children()):
19
+ # If the child is itself a container, recurse first
20
+ add_lora_adapters(child, r, lora_alpha, lora_dropout, target_keywords)
21
+
22
+ # Replace target linear layers
23
+ if isinstance(child, nn.Linear) and any(k in name for k in target_keywords):
24
+ lora_layer = lora.Linear( # loralib wrapper
25
+ in_features=child.in_features,
26
+ out_features=child.out_features,
27
+ r=r,
28
+ lora_alpha=lora_alpha,
29
+ lora_dropout=lora_dropout,
30
+ bias=child.bias is not None,
31
+ )
32
+
33
+ # copy original weights so that behaviour is identical pre-training
34
+ lora_layer.weight.data = child.weight.data.clone()
35
+ if child.bias is not None:
36
+ lora_layer.bias.data = child.bias.data.clone()
37
+
38
+ setattr(root_module, name, lora_layer) # hot-swap!
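The comment about identical behaviour before training relies on the LoRA update being zero at initialisation. A minimal pure-Python sketch of the forward pass `y = W x + (lora_alpha / r) * B (A x)` is given below. It assumes the usual LoRA convention that `B` starts at zero, and uses `r = 1` with hand-picked values for brevity instead of the defaults in `add_lora_adapters`.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, lora_alpha, r):
    """Frozen linear layer plus a scaled low-rank update."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = lora_alpha / r
    return [b + scale * u for b, u in zip(base, update)]


W = [[1.0, 2.0], [3.0, 4.0]]  # copied pretrained weight (out=2, in=2)
A = [[0.5, 0.25]]             # r x in (randomly initialised in practice)
B_zero = [[0.0], [0.0]]       # out x r, zero at initialisation
B_trained = [[1.0], [0.0]]    # illustrative values after some training
x = [1.0, 1.0]

y_base = matvec(W, x)
y_init = lora_forward(W, A, B_zero, x, lora_alpha=32, r=1)
y_trained = lora_forward(W, A, B_trained, x, lora_alpha=32, r=1)
```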
spectre/utils/masking.py ADDED
@@ -0,0 +1,196 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+
7
+ def _random_block_mask(
8
+ size: Tuple[int, int, int],
9
+ num_masks: int,
10
+ min_num_masks_per_block: int = 4,
11
+ max_num_masks_per_block: Optional[int] = None,
12
+ max_attempts_per_block: int = 10,
13
+ generator: Optional[torch.Generator] = None,
14
+ device: Optional[Union[torch.device, str]] = None,
15
+ ) -> torch.Tensor:
16
+ """3D helper: generate a (H, W, D) boolean mask by placing cuboidal blocks.
17
+
18
+ - size: (H, W, D)
19
+ - num_masks: target total number of masked voxels for this image
20
+ - min_num_masks_per_block / max_num_masks_per_block: voxel-range per block
21
+ """
22
+ H, W, D = size
23
+ total = H * W * D
24
+ num_masks = min(max(0, int(num_masks)), total)
25
+
26
+ if max_num_masks_per_block is None:
27
+ max_num_masks_per_block = max(1, num_masks)
28
+
29
+ mask = torch.zeros((H, W, D), dtype=torch.bool, device=device)
30
+ masked_count = 0
31
+ global_attempts = 0
32
+
33
+ orders = [(0, 1, 2), (1, 2, 0), (2, 0, 1)]
34
+
35
+ # Try to place blocks until we have enough masked voxels or we exceed attempts
36
+ while masked_count < num_masks and global_attempts < max_attempts_per_block:
37
+ global_attempts += 1
38
+
39
+ # choose target voxels for this block
40
+ target_voxels = int(torch.randint(
41
+ min_num_masks_per_block, max_num_masks_per_block + 1, (1,), generator=generator
42
+ ).item())
43
+
44
+ found = False
45
+ local_attempts = 0
46
+ while not found and local_attempts < max_attempts_per_block:
47
+ local_attempts += 1
48
+
49
+ # random pick order for dims to reduce bias
50
+ order_idx = int(torch.randint(0, 3, (1,), generator=generator).item())
51
+ order = orders[order_idx]
52
+
53
+ # pick first dimension
54
+ if order[0] == 0:
55
+ h = int(torch.randint(1, min(H, target_voxels) + 1, (1,), generator=generator).item())
56
+ elif order[0] == 1:
57
+ w = int(torch.randint(1, min(W, target_voxels) + 1, (1,), generator=generator).item())
58
+ else:
59
+ d = int(torch.randint(1, min(D, target_voxels) + 1, (1,), generator=generator).item())
60
+
61
+ # progressively choose remaining dims while ensuring feasibility
62
+ try:
63
+ if order[0] == 0:
64
+ # h chosen -> pick w then compute d_needed
65
+ max_w = max(1, min(W, target_voxels // h))
66
+ w = int(torch.randint(1, max_w + 1, (1,), generator=generator).item())
67
+ d_needed = math.ceil(target_voxels / (h * w))
68
+ if d_needed <= D:
69
+ d = max(1, d_needed)
70
+ found = True
71
+ elif order[0] == 1:
72
+ # w chosen -> pick d then compute h_needed
73
+ max_d = max(1, min(D, target_voxels // w))
74
+ d = int(torch.randint(1, max_d + 1, (1,), generator=generator).item())
75
+ h_needed = math.ceil(target_voxels / (d * w))
76
+ if h_needed <= H:
77
+ h = max(1, h_needed)
78
+ found = True
79
+ else:
80
+ # d chosen -> pick h then compute w_needed
81
+ max_h = max(1, min(H, target_voxels // d))
82
+ h = int(torch.randint(1, max_h + 1, (1,), generator=generator).item())
83
+ w_needed = math.ceil(target_voxels / (d * h))
84
+ if w_needed <= W:
85
+ w = max(1, w_needed)
86
+ found = True
87
+
88
+ except ValueError:
89
+ # in case of invalid ranges (defensive); just continue trying
90
+ continue
91
+
92
+ # fallback alternative attempt: try simple factorization heuristics
93
+ if not found:
94
+ # attempt small-to-large factorization
95
+ for hh in range(1, min(H, target_voxels) + 1):
96
+ for ww in range(1, min(W, target_voxels // hh) + 1):
97
+ dd = math.ceil(target_voxels / (hh * ww))
98
+ if dd <= D:
99
+ h, w, d = hh, ww, dd
100
+ found = True
101
+ break
102
+ if found:
103
+ break
104
+
105
+ if not found:
106
+ # couldn't find a fitting block this global attempt; move on
107
+ continue
108
+
109
+ # clamp block dims to volume just in case and ensure at least 1
110
+ h = min(max(1, int(h)), H)
111
+ w = min(max(1, int(w)), W)
112
+ d = min(max(1, int(d)), D)
113
+
114
+ # choose random location so block fits
115
+ x0 = int(torch.randint(0, (H - h) + 1, (1,), generator=generator).item()) if H - h > 0 else 0
116
+ y0 = int(torch.randint(0, (W - w) + 1, (1,), generator=generator).item()) if W - w > 0 else 0
117
+ z0 = int(torch.randint(0, (D - d) + 1, (1,), generator=generator).item()) if D - d > 0 else 0
118
+
119
+ mask[x0 : x0 + h, y0 : y0 + w, z0 : z0 + d] = True
120
+ masked_count = int(mask.sum().item())
121
+
122
+ # If still short, fill remaining voxels at random positions
123
+ if masked_count < num_masks:
124
+ remaining = num_masks - masked_count
125
+ indices = torch.nonzero(~mask, as_tuple=False)
126
+ if indices.numel() > 0:
127
+ perm = torch.randperm(indices.shape[0], generator=generator, device=mask.device)
128
+ pick = indices[perm[:remaining]]
129
+ mask[pick[:, 0], pick[:, 1], pick[:, 2]] = True
130
+
131
+ return mask
132
+
133
+
134
+ def random_block_mask(
135
+ size: Tuple[int, int, int, int],
136
+ batch_mask_ratio: float = 0.5,
137
+ min_image_mask_ratio: float = 0.1,
138
+ max_image_mask_ratio: float = 0.5,
139
+ min_num_masks_per_block: int = 4,
140
+ max_num_masks_per_block: Optional[int] = None,
141
+ max_attempts_per_block: int = 10,
142
+ generator: Optional[torch.Generator] = None,
143
+ device: Optional[Union[torch.device, str]] = None,
144
+ ) -> torch.Tensor:
145
+ """Create random block masks for 3D volumes only.
146
+
147
+ Args:
148
+ size: (B, H, W, D)
149
+ batch_mask_ratio: fraction of images in the batch to apply masking to
150
+ min_image_mask_ratio / max_image_mask_ratio: per-image mask fraction range
151
+ min_num_masks_per_block / max_num_masks_per_block: voxels per block range
152
+ max_attempts_per_block: attempts to find a fitting block
153
+ generator: optional torch.Generator for reproducibility.
154
+ device: device for returned tensor
155
+
156
+ Returns:
157
+ boolean tensor with shape (B, H, W, D)
158
+ """
159
+ if len(size) != 4:
160
+ raise ValueError("size must be (B, H, W, D) for 3D masking.")
161
+
162
+ B, H, W, D = size
163
+
164
+ if max_image_mask_ratio < min_image_mask_ratio:
165
+ raise ValueError("max_image_mask_ratio must be >= min_image_mask_ratio.")
166
+
167
+ num_images_masked = int(B * batch_mask_ratio)
168
+ probs = torch.linspace(min_image_mask_ratio, max_image_mask_ratio, num_images_masked + 1).tolist()
169
+
170
+ image_masks = []
171
+ total_voxels = H * W * D
172
+
173
+ for prob_min, prob_max in zip(probs[:-1], probs[1:]):
174
+ # choose number of masked voxels for this image
175
+ u = float(prob_min + (prob_max - prob_min) * torch.rand(1, generator=generator).item())
176
+ num_mask = int(total_voxels * u)
177
+ image_masks.append(
178
+ _random_block_mask(
179
+ size=(H, W, D),
180
+ num_masks=num_mask,
181
+ min_num_masks_per_block=min_num_masks_per_block,
182
+ max_num_masks_per_block=max_num_masks_per_block,
183
+ max_attempts_per_block=max_attempts_per_block,
184
+ generator=generator,
185
+ device=device,
186
+ )
187
+ )
188
+
189
+ # Add non-masked images (all False) to fill the batch
190
+ for _ in range(num_images_masked, B):
191
+ image_masks.append(torch.zeros((H, W, D), dtype=torch.bool, device=device))
192
+
193
+ perm = torch.randperm(B, generator=generator).tolist()
194
+ image_masks = [image_masks[i] for i in perm]
195
+
196
+ return torch.stack(image_masks)
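The deterministic fallback inside `_random_block_mask` searches for block dimensions that cover a target voxel count while fitting in the volume. The sketch below isolates that search in a hypothetical helper `fit_block` without any torch dependency.

```python
import math


def fit_block(target_voxels, H, W, D):
    """Return (h, w, d) with h * w * d >= target_voxels that fits in (H, W, D), else None."""
    for hh in range(1, min(H, target_voxels) + 1):
        for ww in range(1, min(W, target_voxels // hh) + 1):
            dd = math.ceil(target_voxels / (hh * ww))
            if dd <= D:
                return hh, ww, dd
    return None


exact = fit_block(12, 4, 4, 4)      # 1 * 3 * 4 covers exactly 12 voxels
rounded = fit_block(5, 2, 2, 2)     # rounds up to the full 2 * 2 * 2 volume
too_big = fit_block(100, 4, 4, 4)   # no block inside a 64-voxel volume reaches 100
```

Because the search starts from the smallest `h`, it tends to return thin blocks. The randomized attempts that run before this fallback in the original are what give more varied shapes.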
spectre/utils/modeling.py ADDED
@@ -0,0 +1,550 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from enum import Enum
5
+ from typing import List, Tuple, Optional, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+
12
+
13
+ def deactivate_requires_grad_and_to_eval(model: nn.Module):
14
+ """Deactivates the requires_grad flag for all parameters of a model and sets it to eval mode.
15
+
16
+ This has the same effect as permanently executing the model within a `torch.no_grad()`
17
+ context. Use this method to disable gradient computation and therefore
18
+ training for a model.
19
+
20
+ Examples:
21
+ >>> backbone = resnet18()
22
+ >>> deactivate_requires_grad_and_to_eval(backbone)
23
+ """
24
+ for param in model.parameters():
25
+ param.requires_grad = False
26
+ model.eval()
27
+
28
+
29
+ def activate_requires_grad_and_to_train(model: nn.Module):
30
+ """Activates the requires_grad flag for all parameters of a model and sets it to train mode.
31
+
32
+ Use this method to activate gradients for a model (e.g. after deactivating
33
+ them using `deactivate_requires_grad_and_to_eval(...)`).
34
+
35
+ Examples:
36
+ >>> backbone = resnet18()
37
+ >>> activate_requires_grad_and_to_train(backbone)
38
+ """
39
+ for param in model.parameters():
40
+ param.requires_grad = True
41
+ model.train()
42
+
43
+
44
+ @torch.no_grad()
45
+ def update_momentum(model: nn.Module, model_ema: nn.Module, m: float):
46
+ """Updates parameters of `model_ema` with Exponential Moving Average of `model`
47
+
48
+ Momentum encoders are a crucial component of models such as MoCo or BYOL.
49
+
50
+ Examples:
51
+ >>> backbone = resnet18()
52
+ >>> projection_head = MoCoProjectionHead()
53
+ >>> backbone_momentum = copy.deepcopy(backbone)
54
+ >>> projection_head_momentum = copy.deepcopy(projection_head)
55
+ >>>
56
+ >>> # update momentum
57
+ >>> update_momentum(backbone, backbone_momentum, m=0.999)
58
+ >>> update_momentum(projection_head, projection_head_momentum, m=0.999)
59
+ """
60
+ for param_ema, param in zip(model_ema.parameters(), model.parameters()):
61
+ param_ema.data = param_ema.data * m + param.data * (1.0 - m)
62
+
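The update rule in `update_momentum` can be illustrated without PyTorch. The following is a minimal sketch that applies the same exponential moving average to plain Python lists; the helper name `ema_update` is hypothetical and not part of this commit.

```python
from typing import List


def ema_update(ema: List[float], online: List[float], m: float) -> List[float]:
    """Return m * ema + (1 - m) * online, element-wise (hypothetical helper)."""
    return [e * m + o * (1.0 - m) for e, o in zip(ema, online)]


# The EMA copy starts at zero and drifts toward the online values.
ema_params = [0.0, 0.0]
online_params = [1.0, 2.0]
for _ in range(3):
    ema_params = ema_update(ema_params, online_params, m=0.5)
print(ema_params)
```

With a momentum close to 1 (such as the `m=0.999` used in the docstring example), the EMA copy changes much more slowly than in this toy setting.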
63
+
64
+ def update_drop_path_rate(
65
+ model: "VisionTransformer",
66
+ drop_path_rate: float,
67
+ mode: str = "linear",
68
+ ) -> None:
69
+ """Updates the drop path rate in a VisionTransformer model.
70
+
71
+ Args:
72
+ model:
73
+ VisionTransformer model.
74
+ drop_path_rate:
75
+ Maximum drop path rate.
76
+ mode:
77
+ Drop path rate update mode. Can be "linear" or "uniform". Linear increases
78
+ the drop path rate from 0 to drop_path_rate over the depth of the model.
79
+ Uniform sets the drop path rate to drop_path_rate for all blocks.
80
+ Raises:
81
+ ValueError: If an unknown mode is provided.
82
+ """
83
+ from timm.layers import DropPath
84
+
85
+ total_depth = len(model.blocks)
86
+
87
+ # Determine drop path rates based on the specified mode
88
+ if mode == "linear":
89
+ drop_probabilities = np.linspace(0, drop_path_rate, total_depth)
90
+ elif mode == "uniform":
91
+ drop_probabilities = [drop_path_rate for _ in range(total_depth)]
92
+ else:
93
+ raise ValueError(
94
+ f"Unknown mode: '{mode}', supported modes are 'linear' and 'uniform'."
95
+ )
96
+
97
+ # Update the drop path rate for each block in the model
98
+ for block, drop_prob in zip(model.blocks, drop_probabilities):
99
+ if drop_prob > 0.0:
100
+ block.drop_path1 = DropPath(drop_prob=drop_prob)
101
+ block.drop_path2 = DropPath(drop_prob=drop_prob)
102
+ else:
103
+ block.drop_path1 = nn.Identity()
104
+ block.drop_path2 = nn.Identity()
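To make the two modes of `update_drop_path_rate` concrete, here is a small dependency-free sketch of the per-block rates it would assign. The helper `drop_path_rates` is hypothetical; it only mirrors the `np.linspace` and constant-list logic above and does not touch a model.

```python
from typing import List


def drop_path_rates(depth: int, max_rate: float, mode: str = "linear") -> List[float]:
    """Per-block drop path rates (hypothetical helper mirroring the two modes)."""
    if mode == "linear":
        if depth == 1:
            # np.linspace(0, max_rate, 1) yields only the start value.
            return [0.0]
        return [max_rate * i / (depth - 1) for i in range(depth)]
    if mode == "uniform":
        return [max_rate] * depth
    raise ValueError(f"Unknown mode: '{mode}', supported modes are 'linear' and 'uniform'.")


linear_rates = drop_path_rates(depth=5, max_rate=0.4)
uniform_rates = drop_path_rates(depth=5, max_rate=0.4, mode="uniform")
print(linear_rates)
print(uniform_rates)
```

In linear mode the first block always receives a rate of 0, which is why the function above replaces its drop path layers with `nn.Identity()`.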
105
+
106
+
107
+ def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
108
+ """Repeats a token size times.
109
+
110
+ Args:
111
+ token: Token tensor with shape (1, 1, dim).
112
+ size: (batch_size, sequence_length) tuple.
113
+
114
+ Returns:
115
+ Tensor with shape (batch_size, sequence_length, dim) containing copies
116
+ of the input token.
117
+ """
118
+ batch_size, sequence_length = size
119
+ return token.repeat(batch_size, sequence_length, 1)
120
+
121
+
122
+ def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
123
+ """Expands the index along the last dimension of the input tokens.
124
+
125
+ Args:
126
+ index:
127
+ Index tensor with shape (batch_size, idx_length) where each entry is
128
+ an index in [0, sequence_length).
129
+ tokens:
130
+ Tokens tensor with shape (batch_size, sequence_length, dim).
131
+
132
+ Returns:
133
+ Index tensor with shape (batch_size, idx_length, dim) where the original
134
+ indices are repeated dim times along the last dimension.
135
+ """
136
+ dim = tokens.shape[-1]
137
+ index = index.unsqueeze(-1).expand(-1, -1, dim)
138
+ return index
139
+
140
+
141
+ def get_at_index(tokens: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
142
+ """Selects tokens at index.
143
+
144
+ Args:
145
+ tokens:
146
+ Token tensor with shape (batch_size, sequence_length, dim).
147
+ index:
148
+ Index tensor with shape (batch_size, index_length) where each entry is
149
+ an index in [0, sequence_length).
150
+
151
+ Returns:
152
+ Token tensor with shape (batch_size, index_length, dim) containing the
153
+ selected tokens.
154
+ """
155
+ index = expand_index_like(index, tokens)
156
+ return torch.gather(tokens, 1, index)
157
+
158
+
159
+ def set_at_index(
160
+ tokens: torch.Tensor, index: torch.Tensor, value: torch.Tensor
161
+ ) -> torch.Tensor:
162
+ """Copies all values into the input tensor at the given indices.
163
+
164
+ Args:
165
+ tokens: Tokens tensor with shape (batch_size, sequence_length, dim).
166
+ index: Index tensor with shape (batch_size, index_length).
167
+ value: Value tensor with shape (batch_size, index_length, dim).
168
+
169
+ Returns:
170
+ Tokens tensor with shape (batch_size, sequence_length, dim) containing
171
+ the new values.
172
+ """
173
+ index = expand_index_like(index, tokens)
174
+ return torch.scatter(tokens, 1, index, value)
175
+
176
+
177
+ def mask_at_index(
178
+ tokens: torch.Tensor, index: torch.Tensor, mask_token: torch.Tensor
179
+ ) -> torch.Tensor:
180
+ """Copies mask token into the input tensor at the given indices.
181
+
182
+ Args:
183
+ tokens:
184
+ Tokens tensor with shape (batch_size, sequence_length, dim).
185
+ index:
186
+ Index tensor with shape (batch_size, index_length).
187
+ mask_token:
188
+ Value tensor with shape (1, 1, dim).
189
+
190
+ Returns:
191
+ Tokens tensor with shape (batch_size, sequence_length, dim) containing
192
+ the new values.
193
+
194
+ """
195
+ mask = tokens.new_zeros(tokens.shape)
196
+ mask = set_at_index(mask, index, 1)
197
+ return (1 - mask) * tokens + mask * mask_token
198
+
199
+
200
+ def mask_bool(tokens: torch.Tensor, mask: torch.Tensor, mask_token: torch.Tensor) -> torch.Tensor:
201
+ """Returns a tensor with tokens replaced by the mask tokens in all positions where
202
+ the mask is True.
203
+
204
+ Args:
205
+ tokens:
206
+ Tokens tensor with shape (batch_size, sequence_length, dim).
207
+ mask:
208
+ Boolean mask tensor with shape (batch_size, sequence_length).
209
+ mask_token:
210
+ Mask token with shape (1, 1, dim).
211
+
212
+ Returns:
213
+ Tokens tensor with shape (batch_size, sequence_length, dim) where tokens[i, j]
214
+ is replaced by the mask token if mask[i, j] is True.
215
+ """
216
+ # Convert to int for multiplication.
217
+ mask = mask.unsqueeze(-1).to(torch.bool).to(torch.int)
218
+ return (1 - mask) * tokens + mask * mask_token
219
+
220
+
221
+ def patchify(images: torch.Tensor, patch_size: Tuple[int, int, int]) -> torch.Tensor:
222
+ """Converts a batch of input images into patches.
223
+
224
+ Args:
225
+ images:
226
+ Images tensor with shape (batch_size, channels, height, width, depth)
227
+ patch_size:
228
+ Patch size in voxels as (patch_height, patch_width, patch_depth). Image height,
229
+ width, and depth must be multiples of the corresponding patch size.
230
+
231
+ Returns:
232
+ Patches tensor with shape (batch_size, num_patches, channels * math.prod(patch_size))
233
+ where num_patches = (height / patch_size[0]) * (width / patch_size[1]) * (depth / patch_size[2]).
234
+
235
+ """
236
+ N, C, H, W, D = images.shape
237
+ assert (
238
+ H % patch_size[0] == 0
239
+ and W % patch_size[1] == 0
240
+ and D % patch_size[2] == 0
241
+ ), "Image height, width, and depth must be multiples of the patch size."
242
+
243
+ patch_h = H // patch_size[0]
244
+ patch_w = W // patch_size[1]
245
+ patch_d = D // patch_size[2]
246
+
247
+ num_patches = patch_h * patch_w * patch_d
248
+ patches = images.reshape(shape=(
249
+ N, C,
250
+ patch_h, patch_size[0],
251
+ patch_w, patch_size[1],
252
+ patch_d, patch_size[2],
253
+ ))
254
+ patches = torch.einsum("nchpwqdr->nhwdpqrc", patches)
255
+ patches = patches.reshape(shape=(N, num_patches, math.prod(patch_size) * C))
256
+ return patches
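The output shape of `patchify` follows directly from the input shape and the patch size. The sketch below computes it with plain integers; `patchify_shape` is a hypothetical helper used only for illustration, and the input shape is an assumed example.

```python
import math
from typing import Sequence, Tuple


def patchify_shape(
    image_shape: Sequence[int], patch_size: Sequence[int]
) -> Tuple[int, int, int]:
    """Shape of the patches tensor for an (N, C, H, W, D) input (hypothetical helper)."""
    n, c, h, w, d = image_shape
    if h % patch_size[0] or w % patch_size[1] or d % patch_size[2]:
        raise ValueError("Image height, width, and depth must be multiples of the patch size.")
    # Number of patches along each spatial axis, multiplied together.
    num_patches = (h // patch_size[0]) * (w // patch_size[1]) * (d // patch_size[2])
    # Each patch is flattened together with its channels.
    return n, num_patches, c * math.prod(patch_size)


patches_shape = patchify_shape((2, 1, 128, 128, 64), (16, 16, 8))
print(patches_shape)
```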
257
+
258
+
259
+ def random_token_mask(
260
+ size: Tuple[int, int],
261
+ mask_ratio: float = 0.6,
262
+ mask_class_token: bool = False,
263
+ device: Optional[Union[torch.device, str]] = None,
264
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
265
+ """Creates random token masks.
266
+
267
+ Args:
268
+ size:
269
+ Size of the token batch for which to generate masks.
270
+ Should be (batch_size, sequence_length).
271
+ mask_ratio:
272
+ Fraction of tokens to mask, in the range [0, 1].
273
+ mask_class_token:
274
+ If False the class token is never masked. If True the class token
275
+ might be masked.
276
+ device:
277
+ Device on which to create the index masks.
278
+
279
+ Returns:
280
+ A (index_keep, index_mask) tuple where each index is a tensor.
281
+ index_keep contains the indices of the unmasked tokens and has shape
282
+ (batch_size, num_keep). index_mask contains the indices of the masked
283
+ tokens and has shape (batch_size, sequence_length - num_keep).
284
+ num_keep is equal to int(sequence_length * (1 - mask_ratio)).
285
+
286
+ """
287
+ batch_size, sequence_length = size
288
+ num_keep = int(sequence_length * (1 - mask_ratio))
289
+
290
+ noise = torch.rand(batch_size, sequence_length, device=device)
291
+ if not mask_class_token and sequence_length > 0:
292
+ # make sure that class token is not masked
293
+ noise[:, 0] = -1
294
+ num_keep = max(1, num_keep)
295
+
296
+ # get indices of tokens to keep
297
+ indices = torch.argsort(noise, dim=1)
298
+ idx_keep = indices[:, :num_keep]
299
+ idx_mask = indices[:, num_keep:]
300
+
301
+ return idx_keep, idx_mask
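The argsort-over-noise trick used by `random_token_mask` can be sketched for a single sequence with the standard library only. The function name `random_keep_mask_split` and the `seed` parameter are assumptions made for this illustration; the original works on batched tensors.

```python
import random
from typing import List, Tuple


def random_keep_mask_split(
    sequence_length: int,
    mask_ratio: float = 0.6,
    mask_class_token: bool = False,
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Split token indices into (keep, mask) for one sequence (hypothetical helper)."""
    rng = random.Random(seed)
    num_keep = int(sequence_length * (1 - mask_ratio))
    noise = [rng.random() for _ in range(sequence_length)]
    if not mask_class_token and sequence_length > 0:
        # The lowest noise value sorts first, so the class token is always kept.
        noise[0] = -1.0
        num_keep = max(1, num_keep)
    order = sorted(range(sequence_length), key=noise.__getitem__)
    return order[:num_keep], order[num_keep:]


idx_keep, idx_mask = random_keep_mask_split(sequence_length=10, mask_ratio=0.6)
print(idx_keep, idx_mask)
```

Every index lands in exactly one of the two lists, so the pair can be used with helpers such as `get_at_index` and `mask_at_index` above.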
302
+
303
+
304
+ def resample_abs_pos_embed(
305
+ posemb: torch.Tensor,
306
+ new_size: List[int],
307
+ old_size: List[int],
308
+ num_prefix_tokens: int = 1,
309
+ interpolation: str = 'trilinear',
310
+ ):
311
+ # skip resampling when the grid size is unchanged
312
+ num_pos_tokens = posemb.shape[1]
313
+ num_new_tokens = new_size[0] * new_size[1] * new_size[2] + num_prefix_tokens
314
+ if num_new_tokens == num_pos_tokens and list(new_size) == list(old_size):
315
+ return posemb
316
+
317
+ if num_prefix_tokens:
318
+ posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
319
+ else:
320
+ posemb_prefix, posemb = None, posemb
321
+
322
+ # do the interpolation
323
+ embed_dim = posemb.shape[-1]
324
+ orig_dtype = posemb.dtype
325
+ posemb = posemb.float() # interpolate needs float32
326
+ posemb = posemb.reshape(1, old_size[0], old_size[1], old_size[2], -1).permute(0, 4, 1, 2, 3)
327
+ posemb = F.interpolate(posemb, size=new_size, mode=interpolation)
328
+ posemb = posemb.permute(0, 2, 3, 4, 1).reshape(1, -1, embed_dim)
329
+ posemb = posemb.to(orig_dtype)
330
+
331
+ # add back extra (class, etc) prefix tokens
332
+ if posemb_prefix is not None:
333
+ posemb = torch.cat([posemb_prefix, posemb], dim=1)
334
+
335
+ return posemb
336
+
337
+
338
+ def resample_abs_pos_embed_nhwdc(
339
+ posemb: torch.Tensor,
340
+ new_size: List[int],
341
+ interpolation: str = 'trilinear',
342
+ ):
343
+ if new_size[0] == posemb.shape[-4] and new_size[1] == posemb.shape[-3] and new_size[2] == posemb.shape[-2]:
344
+ return posemb
345
+
346
+ orig_dtype = posemb.dtype
347
+ posemb = posemb.float()
348
+ posemb = posemb.reshape(1, posemb.shape[-4], posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 4, 1, 2, 3)
349
+ posemb = F.interpolate(posemb, size=new_size, mode=interpolation)
350
+ posemb = posemb.permute(0, 2, 3, 4, 1).to(orig_dtype)
351
+
352
+ return posemb
353
+
354
+
355
+ def resample_patch_embed(
356
+ patch_embed,
357
+ new_size: List[int],
358
+ interpolation: str = 'trilinear',
359
+ ):
360
+ """Resample the weights of the patch embedding kernel to target resolution.
361
+ We resample the patch embedding kernel by approximately inverting the effect
362
+ of patch resizing.
363
+
364
+ Code based on:
365
+ https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
366
+
367
+ With this resizing, we can for example load a B/8 filter into a B/16 model
368
+ and, on 2x larger input image, the result will match.
369
+
370
+ Args:
371
+ patch_embed: original parameter to be resized.
372
+ new_size (tuple[int, int, int]): target shape (height, width, depth).
373
+ interpolation (str): interpolation for resize
374
+ Returns:
375
+ Resized patch embedding kernel.
376
+ """
377
+ import numpy as np
378
+ try:
379
+ from torch import vmap
380
+ except ImportError:
381
+ from functorch import vmap
382
+
383
+ assert len(patch_embed.shape) == 5, "Five dimensions expected"
384
+ assert len(new_size) == 3, "New shape should only be (height, width, depth)"
385
+ old_size = patch_embed.shape[-3:]
386
+ if tuple(old_size) == tuple(new_size):
387
+ return patch_embed
388
+
389
+ def resize(x_np, _new_size):
390
+ x_tf = torch.Tensor(x_np)[None, None, ...]
391
+ x_upsampled = F.interpolate(
392
+ x_tf, size=_new_size, mode=interpolation)[0, 0, ...].numpy()
393
+ return x_upsampled
394
+
395
+ def get_resize_mat(_old_size, _new_size):
396
+ mat = []
397
+ for i in range(np.prod(_old_size)):
398
+ basis_vec = np.zeros(_old_size)
399
+ basis_vec[np.unravel_index(i, _old_size)] = 1.
400
+ mat.append(resize(basis_vec, _new_size).reshape(-1))
401
+ return np.stack(mat).T
402
+
403
+ resize_mat = get_resize_mat(old_size, new_size)
404
+ resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
405
+
406
+ def resample_kernel(kernel):
407
+ resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
408
+ return resampled_kernel.reshape(new_size)
409
+
410
+ v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
411
+ orig_dtype = patch_embed.dtype
412
+ patch_embed = patch_embed.float()
413
+ patch_embed = v_resample_kernel(patch_embed)
414
+ patch_embed = patch_embed.to(orig_dtype)
415
+ return patch_embed
416
+
417
+
418
+ def feature_take_indices(
419
+ num_features: int,
420
+ indices: Optional[Union[int, List[int]]] = None,
421
+ as_set: bool = False,
422
+ ) -> Tuple[List[int], int]:
423
+ """ Determine the absolute feature indices to 'take' from.
424
+
425
+ Note: This function can be called in forward() so must be torchscript compatible,
426
+ which requires some incomplete typing and workaround hacks.
427
+
428
+ Args:
429
+ num_features: total number of features to select from
430
+ indices: indices to select,
431
+ None -> select all
432
+ int -> select last n
433
+ list/tuple of int -> return specified (-ve indices specify from end)
434
+ as_set: return as a set
435
+
436
+ Returns:
437
+ List (or set) of absolute (from beginning) indices, Maximum index
438
+ """
439
+ if indices is None:
440
+ indices = num_features # all features if None
441
+
442
+ if isinstance(indices, int):
443
+ # convert int -> last n indices
444
+ assert 0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})'
445
+ take_indices = [num_features - indices + i for i in range(indices)]
446
+ else:
447
+ take_indices: List[int] = []
448
+ for i in indices:
449
+ idx = num_features + i if i < 0 else i
450
+ assert 0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})'
451
+ take_indices.append(idx)
452
+
453
+ if not torch.jit.is_scripting() and as_set:
454
+ return set(take_indices), max(take_indices)
455
+
456
+ return take_indices, max(take_indices)
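The index resolution in `feature_take_indices` is easiest to follow on small inputs. Below is a simplified sketch without the TorchScript and `as_set` handling; `resolve_indices` is a hypothetical name, and it raises `ValueError` where the original uses assertions.

```python
from typing import List, Optional, Sequence, Tuple, Union


def resolve_indices(
    num_features: int,
    indices: Optional[Union[int, Sequence[int]]] = None,
) -> Tuple[List[int], int]:
    """Resolve 'last n' or explicit (possibly negative) indices (hypothetical helper)."""
    if indices is None:
        indices = num_features  # take all features
    if isinstance(indices, int):
        if not 0 < indices <= num_features:
            raise ValueError(f"last-n ({indices}) is out of range (1 to {num_features})")
        take = list(range(num_features - indices, num_features))
    else:
        take = []
        for i in indices:
            idx = num_features + i if i < 0 else i  # negative indices count from the end
            if not 0 <= idx < num_features:
                raise ValueError(f"feature index {idx} is out of range")
            take.append(idx)
    return take, max(take)


last_four = resolve_indices(12, 4)
explicit = resolve_indices(12, [2, -1])
print(last_four, explicit)
```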
457
+
458
+
459
+ def global_pool_nlc(
460
+ x: torch.Tensor,
461
+ pool_type: str = 'token',
462
+ num_prefix_tokens: int = 1,
463
+ reduce_include_prefix: bool = False,
464
+ ):
465
+ if not pool_type:
466
+ return x
467
+
468
+ if pool_type == 'token':
469
+ x = x[:, 0] # class token
470
+ else:
471
+ x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
472
+ if pool_type == 'avg':
473
+ x = x.mean(dim=1)
474
+ elif pool_type == 'avgmax':
475
+ x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
476
+ elif pool_type == 'max':
477
+ x = x.amax(dim=1)
478
+ else:
479
+ assert not pool_type, f'Unknown pool type {pool_type}'
480
+
481
+ return x
482
+
483
+
484
+ def cat_keep_shapes(
485
+ x_list: List[torch.Tensor]
486
+ ) -> Tuple[torch.Tensor, List[Tuple[int, ...]], List[int]]:
487
+ if not x_list:
488
+ return torch.empty(0), [], []
489
+
490
+ shapes = [x.shape for x in x_list]
491
+ num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list]
492
+ x_cat = torch.cat([x.flatten(0, -2) for x in x_list], dim=0)
493
+
494
+ return x_cat, shapes, num_tokens
495
+
496
+
497
+ def uncat_with_shapes(
498
+ x_cat: torch.Tensor,
499
+ shapes: List[Tuple[int, ...]],
500
+ num_tokens: List[int]
501
+ ) -> List[torch.Tensor]:
502
+ if not shapes:
503
+ return []
504
+
505
+ x_splitted = torch.split_with_sizes(x_cat, num_tokens, dim=0)
506
+ shapes_adjusted = [shape[:-1] + torch.Size([x_cat.shape[-1]]) for shape in shapes]
507
+ outputs_reshape = [x.reshape(shape) for x, shape in zip(x_splitted, shapes_adjusted)]
508
+
509
+ return outputs_reshape
510
+
511
+
512
+ def last_token_pool(
513
+ last_hidden_states: torch.Tensor,
514
+ attention_mask: torch.Tensor
515
+ ) -> torch.Tensor:
516
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
517
+ if left_padding:
518
+ return last_hidden_states[:, -1]
519
+ else:
520
+ sequence_lengths = attention_mask.sum(dim=1) - 1
521
+ batch_size = last_hidden_states.shape[0]
522
+ return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device),
523
+ sequence_lengths]
524
+
525
+
526
+ class Format(str, Enum):
527
+ NCHWD = 'NCHWD'
528
+ NHWDC = 'NHWDC'
529
+ NCL = 'NCL'
530
+ NLC = 'NLC'
531
+
532
+
533
+ def nchwd_to(x: torch.Tensor, fmt: Format):
534
+ if fmt == Format.NHWDC:
535
+ x = x.permute(0, 2, 3, 4, 1)
536
+ elif fmt == Format.NLC:
537
+ x = x.flatten(2).transpose(1, 2)
538
+ elif fmt == Format.NCL:
539
+ x = x.flatten(2)
540
+ return x
541
+
542
+
543
+ def nhwdc_to(x: torch.Tensor, fmt: Format):
544
+ if fmt == Format.NCHWD:
545
+ x = x.permute(0, 4, 1, 2, 3)
546
+ elif fmt == Format.NLC:
547
+ x = x.flatten(1, 3)
548
+ elif fmt == Format.NCL:
549
+ x = x.flatten(1, 3).transpose(1, 2)
550
+ return x
spectre/utils/param_groups.py ADDED
@@ -0,0 +1,118 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_vit_lr_decay_rate(
5
+ name: str,
6
+ llrd_factor: float = 1.0,
7
+ num_layers: int = 12,
8
+ force_is_backbone: bool = False,
9
+ shift: int = 0,
10
+ ) -> float:
11
+ """
12
+ Get the layer-wise learning rate decay (LLRD) rate for a given parameter name.
13
+
14
+ Args:
15
+ name:
16
+ The name of the parameter.
17
+ llrd_factor:
18
+ The decay factor for each layer.
19
+ num_layers:
20
+ The total number of layers in the model.
21
+ force_is_backbone:
22
+ If True, forces the function to treat the parameter as part of the backbone.
23
+ shift:
24
+ An integer to shift the layer ids, useful when combining multiple modules.
25
+
26
+ Returns:
27
+ The learning rate multiplier for the parameter.
28
+ """
29
+ layer_id = num_layers + 1
30
+ if name.startswith("backbone") or force_is_backbone:
31
+ if (
32
+ ".pos_embed" in name
33
+ or ".patch_embed" in name
34
+ or ".patch_proj" in name
35
+ or ".mask_token" in name
36
+ or ".cls_token" in name
37
+ or ".reg_token" in name
38
+ ):
39
+ layer_id = 0
40
+ elif ".blocks." in name:
41
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + shift
42
+
43
+ return llrd_factor ** (num_layers + 1 - layer_id)
44
+
45
+
46
+ def get_param_groups_with_decay(
47
+ model: nn.Module,
48
+ llrd_factor: float = 1.0,
49
+ patch_embed_lr_mult: float = 1.0,
50
+ projection_head_wd_mult: float = 1.0,
51
+ lora_lr_factor: float = 1.0,
52
+ num_layers: int | None = None,
53
+ ):
54
+
55
+ force_is_backbone = False
56
+ shift = 0
57
+ if num_layers is not None:
58
+ pass # an explicitly provided num_layers takes precedence
59
+ elif hasattr(model, "n_blocks"):
60
+ num_layers = model.n_blocks
61
+ force_is_backbone = True
62
+ elif hasattr(model, "blocks"):
63
+ num_layers = len(model.blocks)
64
+ force_is_backbone = True
65
+ elif hasattr(model, "backbone") and hasattr(model.backbone, "blocks"):
66
+ num_layers = len(model.backbone.blocks)
67
+ elif hasattr(model, "backbone_student") and hasattr(model.backbone_student, "blocks"): # DINO specific
68
+ num_layers = len(model.backbone_student.blocks)
69
+ elif hasattr(model, "backbone_student") and hasattr(model.backbone_student, "vit") and hasattr(model.backbone_student.vit, "blocks"): # DINOv2 specific
70
+ num_layers = len(model.backbone_student.vit.blocks)
71
+ elif hasattr(model, "backbone_image") and hasattr(model.backbone_image, "blocks"): # SigLIP specific
72
+ if not hasattr(model, "feature_comb_image") or model.feature_comb_image is None:
73
+ num_layers = len(model.backbone_image.blocks)
74
+ else:
75
+ num_layers = len(model.backbone_image.blocks) + len(model.feature_comb_image.blocks)
76
+ shift = len(model.backbone_image.blocks)
77
+ force_is_backbone = True
78
+ else:
79
+ num_layers = 0
80
+
81
+ all_param_groups = []
82
+ for n, p in model.named_parameters():
83
+ if not p.requires_grad:
84
+ continue
85
+ if "lora_" not in n:
86
+ s = shift if "feature_comb" in n else 0
87
+ llrd_rate = get_vit_lr_decay_rate(
88
+ n, llrd_factor, num_layers, force_is_backbone, s,
89
+ )
90
+
91
+ d = {
92
+ "name": n,
93
+ "params": p,
94
+ "lr_mult": llrd_rate,
95
+ "wd_mult": 1.0,
96
+ }
97
+
98
+ if "head" in n or "projection" in n:
99
+ d["wd_mult"] = projection_head_wd_mult
100
+
101
+ # No weight decay on biases, norm parameters, layer scale gamma, and Fourier weights (`fourrier_w`)
102
+ if n.endswith("bias") or "norm" in n or "gamma" in n or "fourrier_w" in n:
103
+ d["wd_mult"] = 0.0
104
+
105
+ if "patch_embed" in n:
106
+ d["lr_mult"] *= patch_embed_lr_mult
107
+
108
+ else:
109
+ # LoRA parameters
110
+ d = {
111
+ "name": n,
112
+ "params": p,
113
+ "lr_mult": lora_lr_factor,
114
+ "wd_mult": 1.0,
115
+ }
116
+
117
+ all_param_groups.append(d)
118
+ return all_param_groups
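To show how the layer-wise learning rate decay multipliers come out for typical parameter names, here is a reduced, dependency-free sketch of the rule in `get_vit_lr_decay_rate`. The helper `llrd_multiplier` and the parameter names are hypothetical, and the sketch omits the `shift` and `force_is_backbone` options.

```python
def llrd_multiplier(name: str, llrd_factor: float, num_layers: int) -> float:
    """Learning rate multiplier for one parameter name (hypothetical, reduced helper)."""
    layer_id = num_layers + 1  # default: treated as the last layer (e.g. heads)
    if name.startswith("backbone"):
        if any(key in name for key in (".pos_embed", ".patch_embed", ".cls_token")):
            layer_id = 0  # embeddings and tokens get the strongest decay
        elif ".blocks." in name:
            # ".blocks.<i>.<rest>" -> transformer block i maps to layer id i + 1
            layer_id = int(name[name.find(".blocks."):].split(".")[2]) + 1
    return llrd_factor ** (num_layers + 1 - layer_id)


names = [
    "backbone.patch_embed.proj.weight",
    "backbone.blocks.0.attn.qkv.weight",
    "backbone.blocks.2.mlp.fc1.weight",
    "head.weight",
]
multipliers = {n: llrd_multiplier(n, llrd_factor=0.5, num_layers=3) for n in names}
for n, mult in multipliers.items():
    print(f"{n}: {mult}")
```

Earlier layers receive smaller multipliers. The returned groups only carry `lr_mult` and `wd_mult`; presumably the training loop scales the base learning rate and weight decay with them, but that code is not part of this diff.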
spectre/utils/scheduler.py ADDED
@@ -0,0 +1,236 @@
1
+ import warnings
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def linear_warmup_schedule(
9
+ step: int,
10
+ warmup_steps: int,
11
+ start_value: float,
12
+ end_value: float,
13
+ ) -> float:
14
+ if warmup_steps < 0:
15
+ raise ValueError(f"Warmup steps {warmup_steps} can't be negative.")
16
+ if step < 0:
17
+ raise ValueError(f"Current step number {step} can't be negative.")
18
+ if start_value < 0:
19
+ raise ValueError(f"Start value {start_value} can't be negative.")
20
+ if end_value <= 0:
21
+ raise ValueError(f"End value {end_value} can't be non-positive.")
22
+ if start_value > end_value:
23
+ raise ValueError(
24
+ f"Start value {start_value} must be less than or equal to end value {end_value}."
25
+ )
26
+ if step < warmup_steps:
27
+ return start_value + step / warmup_steps * (end_value - start_value)
28
+ else:
29
+ return end_value
30
+
31
+
32
+ def cosine_schedule(
33
+ step: int,
34
+ max_steps: int,
35
+ start_value: float,
36
+ end_value: float,
37
+ period: Optional[int] = None,
38
+ ) -> float:
39
+ """Use cosine decay to gradually modify start_value to reach target end_value.
40
+
41
+ Args:
42
+ step:
43
+ Current step number.
44
+ max_steps:
45
+ Total number of steps.
46
+ start_value:
47
+ Starting value.
48
+ end_value:
49
+ Target value.
50
+ period:
51
+ The number of steps over which the cosine function completes a full cycle.
52
+ Defaults to max_steps.
53
+
54
+ Returns:
55
+ Cosine decay value.
56
+
57
+ """
58
+ if step < 0:
59
+ raise ValueError(f"Current step number {step} can't be negative.")
60
+ if max_steps < 1:
61
+ raise ValueError(f"Total step number {max_steps} must be >= 1.")
62
+ if period is None and step > max_steps:
63
+ warnings.warn(
64
+ f"Current step number {step} exceeds max_steps {max_steps}.",
65
+ category=RuntimeWarning,
66
+ )
67
+ if period is not None and period <= 0:
68
+ raise ValueError(f"Period {period} must be >= 1")
69
+
70
+ decay: float
71
+ if period is not None: # "cycle" based on period, if provided
72
+ decay = (
73
+ end_value
74
+ - (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2
75
+ )
76
+ elif max_steps == 1:
77
+ # Avoid division by zero
78
+ decay = end_value
79
+ elif step == max_steps:
80
+ # Special case for PyTorch Lightning, which also steps the LR scheduler once
81
+ # after the last training epoch.
82
+ decay = end_value
83
+ else:
84
+ decay = (
85
+ end_value
86
+ - (end_value - start_value)
87
+ * (np.cos(np.pi * step / (max_steps - 1)) + 1)
88
+ / 2
89
+ )
90
+ return decay
91
+
92
+
93
+ def cosine_warmup_schedule(
94
+ step: int,
95
+ max_steps: int,
96
+ start_value: float,
97
+ end_value: float,
98
+ warmup_steps: int,
99
+ warmup_start_value: float,
100
+ warmup_end_value: Optional[float] = None,
101
+ period: Optional[int] = None,
102
+ ) -> float:
103
+ """Use cosine decay to gradually modify start_value to reach target end_value.
104
+
105
+ Uses linear warmup for the first warmup_steps steps.
106
+
107
+ Args:
108
+ step:
109
+ Current step number.
110
+ max_steps:
111
+ Total number of steps.
112
+ start_value:
113
+ Starting value.
114
+ end_value:
115
+ Target value.
116
+ warmup_steps:
117
+ Number of steps for warmup.
118
+ warmup_start_value:
119
+ Starting value for warmup.
120
+ warmup_end_value:
121
+ Target value for warmup. Defaults to start_value.
122
+ period:
123
+ The number of steps over which the cosine function completes a full cycle.
124
+ Defaults to max_steps - warmup_steps.
125
+ Returns:
126
+ Cosine decay value.
127
+ """
128
+ if warmup_steps < 0:
129
+ raise ValueError(f"Warmup steps {warmup_steps} can't be negative.")
130
+ if warmup_steps > max_steps:
131
+ raise ValueError(f"Warmup steps {warmup_steps} must be <= max_steps.")
132
+ if step > max_steps:
133
+ warnings.warn(
134
+ f"Current step number {step} exceeds max_steps {max_steps}.",
135
+ category=RuntimeWarning,
136
+ )
137
+
138
+ if warmup_end_value is None:
139
+ warmup_end_value = start_value
140
+
141
+ if step < warmup_steps:
142
+ # Use step + 1 to reach warmup_end_value at end of warmup. This means that the
143
+ # initial warmup_start_value is skipped which is oftentimes desired when setting
144
+ # it to 0 as this would result in no parameter updates.
145
+ return (
146
+ warmup_start_value
147
+ + (warmup_end_value - warmup_start_value) * (step + 1) / warmup_steps
148
+ )
149
+ else:
150
+ max_steps = max_steps - warmup_steps if period is None else 1
151
+ return cosine_schedule(
152
+ step=step - warmup_steps,
153
+ max_steps=max_steps,
154
+ start_value=start_value,
155
+ end_value=end_value,
156
+ period=period,
157
+ )
158
+
159
+
160
+ class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR):
161
+ """Cosine warmup scheduler for learning rate.
162
+
163
+ Args:
164
+ optimizer:
165
+ Optimizer object to schedule the learning rate.
166
+ warmup_epochs:
167
+ Number of warmup epochs or steps.
168
+ max_epochs:
169
+ Total number of training epochs or steps.
170
+ last_epoch:
171
+ The index of last epoch or step.
172
+ start_value:
173
+ Starting learning rate.
174
+ end_value:
175
+ Target learning rate.
176
+ verbose:
177
+ If True, prints a message to stdout for each update.
178
+ warmup_start_value:
179
+ Starting learning rate for warmup.
180
+ warmup_end_value:
181
+ Target learning rate for warmup. Defaults to start_value.
182
+
183
+ Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index
184
+ can be used. The naming follows the PyTorch convention to use `epoch` for the steps
185
+ in the scheduler.
186
+ """
187
+
188
+ def __init__(
189
+ self,
190
+ optimizer: torch.optim.Optimizer,
191
+ warmup_epochs: int,
192
+ max_epochs: int,
193
+ last_epoch: int = -1,
194
+ start_value: float = 1.0,
195
+ end_value: float = 0.001,
196
+ period: Optional[int] = None,
197
+ verbose: bool = False,
198
+ warmup_start_value: float = 0.0,
199
+ warmup_end_value: Optional[float] = None,
200
+ ) -> None:
201
+ self.warmup_epochs = warmup_epochs
202
+ self.max_epochs = max_epochs
203
+ self.start_value = start_value
204
+ self.end_value = end_value
205
+ self.period = period
206
+ self.warmup_start_value = warmup_start_value
207
+ self.warmup_end_value = warmup_end_value
208
+
209
+ super().__init__(
210
+ optimizer=optimizer,
211
+ lr_lambda=self.scale_lr,
212
+ last_epoch=last_epoch,
213
+ verbose=verbose,
214
+ )
215
+
216
+ def scale_lr(self, epoch: int) -> float:
217
+ """Scale learning rate according to the current epoch number.
218
+
219
+ Args:
220
+ epoch:
221
+ Current epoch number.
222
+
223
+ Returns:
224
+ Scaled learning rate.
225
+
226
+ """
227
+ return cosine_warmup_schedule(
228
+ step=epoch,
229
+ max_steps=self.max_epochs,
230
+ start_value=self.start_value,
231
+ end_value=self.end_value,
232
+ warmup_steps=self.warmup_epochs,
233
+ warmup_start_value=self.warmup_start_value,
234
+ warmup_end_value=self.warmup_end_value,
235
+ period=self.period,
236
+ )
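The shape of the schedule produced by `cosine_warmup_schedule` can be checked with a condensed, standard-library sketch. The helper `warmup_cosine` is hypothetical: it leaves out the `period` option, input validation, and warnings, and it clamps to `end_value` once the decay phase is over, which is a simplification of the original.

```python
import math


def warmup_cosine(
    step: int,
    max_steps: int,
    warmup_steps: int,
    start_value: float = 1.0,
    end_value: float = 0.0,
    warmup_start_value: float = 0.0,
) -> float:
    """Linear warmup followed by cosine decay (hypothetical, condensed helper)."""
    if step < warmup_steps:
        # step + 1 so that start_value is reached at the end of warmup.
        return warmup_start_value + (start_value - warmup_start_value) * (step + 1) / warmup_steps
    decay_steps = max_steps - warmup_steps
    decay_step = step - warmup_steps
    if decay_steps <= 1 or decay_step >= decay_steps:
        return end_value
    cosine = (math.cos(math.pi * decay_step / (decay_steps - 1)) + 1) / 2
    return end_value - (end_value - start_value) * cosine


schedule = [warmup_cosine(s, max_steps=10, warmup_steps=2) for s in range(10)]
for s, value in enumerate(schedule):
    print(s, round(value, 4))
```

Because `CosineWarmupScheduler` subclasses `LambdaLR`, the returned value acts as a multiplier on each parameter group's base learning rate rather than as an absolute learning rate.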