Feature Extraction
Transformers
Safetensors
English
spectre
medical-imaging
ct-scan
3d
vision-transformer
self-supervised-learning
foundation-model
radiology
custom_code
Instructions to use cclaess/SPECTRE-Large with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use cclaess/SPECTRE-Large with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="cclaess/SPECTRE-Large", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("cclaess/SPECTRE-Large", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Initial commit
Browse files- spectre/models/eomt.py +0 -230
- spectre/models/resnet.py +0 -726
- spectre/models/seomt.py +0 -394
- spectre/models/upsample_anything.py +0 -319
- spectre/utils/checkpointing.py +0 -238
- spectre/utils/collate.py +0 -120
- spectre/utils/config.py +0 -91
- spectre/utils/dataloader.py +0 -126
- spectre/utils/distributed.py +0 -92
- spectre/utils/lora.py +0 -38
- spectre/utils/masking.py +0 -196
- spectre/utils/param_groups.py +0 -118
- spectre/utils/scheduler.py +0 -236
spectre/models/eomt.py
DELETED
|
@@ -1,230 +0,0 @@
|
|
| 1 |
-
# Adapted from https://github.com/tue-mps/eomt/
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
import math
|
| 5 |
-
from typing import Optional, Tuple, Union
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
|
| 11 |
-
from spectre.models.layers import LayerNorm3d
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class ScaleBlock(nn.Module):
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
embed_dim: int,
|
| 18 |
-
scale_factors: Union[int, Tuple[int, int, int]] = (2, 2, 2),
|
| 19 |
-
conv1_layer: nn.Module = nn.ConvTranspose3d,
|
| 20 |
-
):
|
| 21 |
-
super().__init__()
|
| 22 |
-
|
| 23 |
-
self.conv1 = conv1_layer(
|
| 24 |
-
embed_dim,
|
| 25 |
-
embed_dim,
|
| 26 |
-
kernel_size=scale_factors,
|
| 27 |
-
stride=scale_factors,
|
| 28 |
-
)
|
| 29 |
-
self.act = nn.GELU()
|
| 30 |
-
self.conv2 = nn.Conv3d(
|
| 31 |
-
embed_dim,
|
| 32 |
-
embed_dim,
|
| 33 |
-
kernel_size=3,
|
| 34 |
-
padding=1,
|
| 35 |
-
groups=embed_dim,
|
| 36 |
-
bias=False,
|
| 37 |
-
)
|
| 38 |
-
self.norm = LayerNorm3d(embed_dim)
|
| 39 |
-
|
| 40 |
-
def forward(self, x):
|
| 41 |
-
x = self.conv1(x)
|
| 42 |
-
x = self.act(x)
|
| 43 |
-
x = self.conv2(x)
|
| 44 |
-
x = self.norm(x)
|
| 45 |
-
|
| 46 |
-
return x
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def compute_upscale_stages(patch_size, min_size=4):
|
| 50 |
-
# Compute how many times to upscale per dimension
|
| 51 |
-
num_stages = []
|
| 52 |
-
for size in patch_size:
|
| 53 |
-
stages = max(0, int(math.log2(size)) - int(math.log2(min_size)))
|
| 54 |
-
num_stages.append(stages)
|
| 55 |
-
return num_stages
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class EoMT(nn.Module):
|
| 59 |
-
def __init__(
|
| 60 |
-
self,
|
| 61 |
-
backbone: "VisionTransformer",
|
| 62 |
-
num_classes: int,
|
| 63 |
-
num_q: int,
|
| 64 |
-
num_blocks=4,
|
| 65 |
-
masked_attn_enabled=True,
|
| 66 |
-
):
|
| 67 |
-
super().__init__()
|
| 68 |
-
self.backbone = backbone
|
| 69 |
-
self.num_q = num_q
|
| 70 |
-
self.num_blocks = num_blocks
|
| 71 |
-
self.masked_attn_enabled = masked_attn_enabled
|
| 72 |
-
|
| 73 |
-
self.register_buffer("attn_mask_probs", torch.ones(num_blocks))
|
| 74 |
-
|
| 75 |
-
self.q = nn.Embedding(num_q, self.backbone.embed_dim)
|
| 76 |
-
|
| 77 |
-
self.class_head = nn.Linear(self.backbone.embed_dim, num_classes + 1)
|
| 78 |
-
|
| 79 |
-
self.mask_head = nn.Sequential(
|
| 80 |
-
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 81 |
-
nn.GELU(),
|
| 82 |
-
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 83 |
-
nn.GELU(),
|
| 84 |
-
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
patch_size = self.backbone.patch_embed.patch_size
|
| 88 |
-
num_upscale_stages = compute_upscale_stages(patch_size, min_size=4)
|
| 89 |
-
|
| 90 |
-
# Build per-stage scale factors list
|
| 91 |
-
max_stages = max(num_upscale_stages)
|
| 92 |
-
upscale_blocks = []
|
| 93 |
-
for stage_idx in range(max_stages):
|
| 94 |
-
# for each dimension, upscale by 2 only if this dimension still has
|
| 95 |
-
# remaining upscales at this stage
|
| 96 |
-
scale_factors = tuple(
|
| 97 |
-
2 if stage_idx < num_upscale_stages[dim] else 1
|
| 98 |
-
for dim in range(len(patch_size))
|
| 99 |
-
)
|
| 100 |
-
upscale_blocks.append(ScaleBlock(self.backbone.embed_dim,
|
| 101 |
-
scale_factors=scale_factors))
|
| 102 |
-
|
| 103 |
-
self.upscale = nn.Sequential(*upscale_blocks)
|
| 104 |
-
|
| 105 |
-
def _predict(self, x: torch.Tensor):
|
| 106 |
-
q = x[:, : self.num_q, :]
|
| 107 |
-
|
| 108 |
-
class_logits = self.class_head(q)
|
| 109 |
-
|
| 110 |
-
x = x[:, self.num_q + self.backbone.num_prefix_tokens :, :]
|
| 111 |
-
x = x.transpose(1, 2).reshape(
|
| 112 |
-
x.shape[0], -1, *self.backbone.patch_embed.grid_size
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
mask_logits = torch.einsum(
|
| 116 |
-
"bqc, bchwd -> bqhwd", self.mask_head(q), self.upscale(x)
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
return mask_logits, class_logits
|
| 120 |
-
|
| 121 |
-
@torch.compiler.disable
|
| 122 |
-
def _disable_attn_mask(self, attn_mask, prob):
|
| 123 |
-
if prob < 1:
|
| 124 |
-
random_queries = (
|
| 125 |
-
torch.rand(attn_mask.shape[0], self.num_q, device=attn_mask.device)
|
| 126 |
-
> prob
|
| 127 |
-
)
|
| 128 |
-
attn_mask[
|
| 129 |
-
:, : self.num_q, self.num_q + self.backbone.num_prefix_tokens :
|
| 130 |
-
][random_queries] = True
|
| 131 |
-
|
| 132 |
-
return attn_mask
|
| 133 |
-
|
| 134 |
-
def _attn(self, module: 'Attention', x: torch.Tensor, mask: Optional[torch.Tensor], rope=None):
|
| 135 |
-
B, N, C = x.shape
|
| 136 |
-
|
| 137 |
-
q = module.q(x).reshape(B, N, module.num_heads, module.head_dim).permute(0, 2, 1, 3)
|
| 138 |
-
kv = module.kv(x).reshape(B, N, 2, module.num_heads, module.head_dim)
|
| 139 |
-
k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
|
| 140 |
-
q, k = module.q_norm(q), module.k_norm(k)
|
| 141 |
-
|
| 142 |
-
if mask is not None:
|
| 143 |
-
mask = mask[:, None, ...].expand(-1, module.num_heads, -1, -1)
|
| 144 |
-
|
| 145 |
-
dropout_p = module.attn_drop.p if self.training else 0.0
|
| 146 |
-
|
| 147 |
-
if rope is not None:
|
| 148 |
-
if isinstance(rope, list):
|
| 149 |
-
rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
|
| 150 |
-
q, k = module.apply_rotary_pos_emb(q, k, rope)
|
| 151 |
-
|
| 152 |
-
if module.fused_attn:
|
| 153 |
-
x = F.scaled_dot_product_attention(q, k, v, mask, dropout_p)
|
| 154 |
-
else:
|
| 155 |
-
attn = (q @ k.transpose(-2, -1)) * module.scale
|
| 156 |
-
if mask is not None:
|
| 157 |
-
attn = attn.masked_fill(~mask, float("-inf"))
|
| 158 |
-
attn = F.softmax(attn, dim=-1)
|
| 159 |
-
attn = module.attn_drop(attn)
|
| 160 |
-
x = attn @ v
|
| 161 |
-
|
| 162 |
-
x = module.proj_drop(module.proj(x.transpose(1, 2).reshape(B, N, C)))
|
| 163 |
-
|
| 164 |
-
return x
|
| 165 |
-
|
| 166 |
-
def forward(self, x: torch.Tensor):
|
| 167 |
-
x = self.backbone.patch_embed(x)
|
| 168 |
-
x, rope = self.backbone._pos_embed(x)
|
| 169 |
-
x = self.backbone.patch_drop(x)
|
| 170 |
-
x = self.backbone.norm_pre(x)
|
| 171 |
-
|
| 172 |
-
attn_mask = None
|
| 173 |
-
mask_logits_per_layer, class_logits_per_layer = [], []
|
| 174 |
-
|
| 175 |
-
for i, block in enumerate(self.backbone.blocks):
|
| 176 |
-
if i == len(self.backbone.blocks) - self.num_blocks:
|
| 177 |
-
x = torch.cat(
|
| 178 |
-
(self.q.weight[None, :, :].expand(x.shape[0], -1, -1), x), dim=1
|
| 179 |
-
)
|
| 180 |
-
|
| 181 |
-
if (
|
| 182 |
-
self.masked_attn_enabled
|
| 183 |
-
and i >= len(self.backbone.blocks) - self.num_blocks
|
| 184 |
-
):
|
| 185 |
-
mask_logits, class_logits = self._predict(self.backbone.norm(x))
|
| 186 |
-
mask_logits_per_layer.append(mask_logits)
|
| 187 |
-
class_logits_per_layer.append(class_logits)
|
| 188 |
-
|
| 189 |
-
attn_mask = torch.ones(
|
| 190 |
-
x.shape[0],
|
| 191 |
-
x.shape[1],
|
| 192 |
-
x.shape[1],
|
| 193 |
-
dtype=torch.bool,
|
| 194 |
-
device=x.device,
|
| 195 |
-
)
|
| 196 |
-
interpolated = F.interpolate(
|
| 197 |
-
mask_logits,
|
| 198 |
-
self.backbone.patch_embed.grid_size,
|
| 199 |
-
mode="trilinear",
|
| 200 |
-
)
|
| 201 |
-
interpolated = interpolated.view(
|
| 202 |
-
interpolated.size(0), interpolated.size(1), -1
|
| 203 |
-
)
|
| 204 |
-
attn_mask[
|
| 205 |
-
:,
|
| 206 |
-
: self.num_q,
|
| 207 |
-
self.num_q + self.backbone.num_prefix_tokens :,
|
| 208 |
-
] = (
|
| 209 |
-
interpolated > 0
|
| 210 |
-
)
|
| 211 |
-
attn_mask = self._disable_attn_mask(
|
| 212 |
-
attn_mask,
|
| 213 |
-
self.attn_mask_probs[
|
| 214 |
-
i - len(self.backbone.blocks) + self.num_blocks
|
| 215 |
-
],
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
x = x + block.drop_path1(
|
| 219 |
-
block.ls1(self._attn(block.attn, block.norm1(x), attn_mask, rope))
|
| 220 |
-
)
|
| 221 |
-
x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))
|
| 222 |
-
|
| 223 |
-
mask_logits, class_logits = self._predict(self.backbone.norm(x))
|
| 224 |
-
mask_logits_per_layer.append(mask_logits)
|
| 225 |
-
class_logits_per_layer.append(class_logits)
|
| 226 |
-
|
| 227 |
-
return (
|
| 228 |
-
mask_logits_per_layer,
|
| 229 |
-
class_logits_per_layer,
|
| 230 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/models/resnet.py
DELETED
|
@@ -1,726 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import math
|
| 3 |
-
from urllib.parse import urlparse
|
| 4 |
-
from typing import Type, Any, Tuple, List, Optional, Union, Dict
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
from spectre.utils import to_ntuple
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
|
| 14 |
-
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
| 15 |
-
return padding
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class BasicBlock(nn.Module):
|
| 19 |
-
expansion = 1
|
| 20 |
-
|
| 21 |
-
def __init__(
|
| 22 |
-
self,
|
| 23 |
-
inplanes: int,
|
| 24 |
-
planes: int,
|
| 25 |
-
stride: int = 1,
|
| 26 |
-
downsample: Optional[nn.Module] = None,
|
| 27 |
-
cardinality: int = 1,
|
| 28 |
-
base_width: int = 64,
|
| 29 |
-
reduce_first: int = 1,
|
| 30 |
-
dilation: int = 1,
|
| 31 |
-
first_dilation: Optional[int] = None,
|
| 32 |
-
act_layer: Type[nn.Module] = nn.ReLU,
|
| 33 |
-
norm_layer: Type[nn.Module] = nn.BatchNorm3d,
|
| 34 |
-
):
|
| 35 |
-
"""
|
| 36 |
-
Args:
|
| 37 |
-
inplanes: Input channel dimensionality.
|
| 38 |
-
planes: Used to determine output channel dimensionalities.
|
| 39 |
-
stride: Stride used in convolution layers.
|
| 40 |
-
downsample: Optional downsample layer for residual path.
|
| 41 |
-
cardinality: Number of convolution groups.
|
| 42 |
-
base_width: Base width used to determine output channel dimensionality.
|
| 43 |
-
reduce_first: Reduction factor for first convolution output width of residual blocks.
|
| 44 |
-
dilation: Dilation rate for convolution layers.
|
| 45 |
-
first_dilation: Dilation rate for first convolution layer.
|
| 46 |
-
act_layer: Activation layer.
|
| 47 |
-
norm_layer: Normalization layer.
|
| 48 |
-
"""
|
| 49 |
-
super(BasicBlock, self).__init__()
|
| 50 |
-
|
| 51 |
-
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
|
| 52 |
-
assert base_width == 64, 'BasicBlock does not support changing base width'
|
| 53 |
-
first_planes = planes // reduce_first
|
| 54 |
-
outplanes = planes * self.expansion
|
| 55 |
-
first_dilation = first_dilation or dilation
|
| 56 |
-
|
| 57 |
-
self.conv1 = nn.Conv3d(
|
| 58 |
-
inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
|
| 59 |
-
dilation=first_dilation, bias=False)
|
| 60 |
-
self.bn1 = norm_layer(first_planes)
|
| 61 |
-
self.act1 = act_layer()
|
| 62 |
-
|
| 63 |
-
self.conv2 = nn.Conv3d(
|
| 64 |
-
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
|
| 65 |
-
self.bn2 = norm_layer(outplanes)
|
| 66 |
-
|
| 67 |
-
self.act2 = act_layer()
|
| 68 |
-
self.downsample = downsample
|
| 69 |
-
self.stride = stride
|
| 70 |
-
self.dilation = dilation
|
| 71 |
-
|
| 72 |
-
def zero_init_last(self):
|
| 73 |
-
if getattr(self.bn2, 'weight', None) is not None:
|
| 74 |
-
nn.init.zeros_(self.bn2.weight)
|
| 75 |
-
|
| 76 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 77 |
-
shortcut = x
|
| 78 |
-
|
| 79 |
-
x = self.conv1(x)
|
| 80 |
-
x = self.bn1(x)
|
| 81 |
-
x = self.act1(x)
|
| 82 |
-
|
| 83 |
-
x = self.conv2(x)
|
| 84 |
-
x = self.bn2(x)
|
| 85 |
-
|
| 86 |
-
if self.downsample is not None:
|
| 87 |
-
shortcut = self.downsample(shortcut)
|
| 88 |
-
x = x + shortcut
|
| 89 |
-
x = self.act2(x)
|
| 90 |
-
|
| 91 |
-
return x
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
class Bottleneck(nn.Module):
|
| 95 |
-
expansion = 4
|
| 96 |
-
|
| 97 |
-
def __init__(
|
| 98 |
-
self,
|
| 99 |
-
inplanes: int,
|
| 100 |
-
planes: int,
|
| 101 |
-
stride: int = 1,
|
| 102 |
-
downsample: Optional[nn.Module] = None,
|
| 103 |
-
cardinality: int = 1,
|
| 104 |
-
base_width: int = 64,
|
| 105 |
-
reduce_first: int = 1,
|
| 106 |
-
dilation: int = 1,
|
| 107 |
-
first_dilation: Optional[int] = None,
|
| 108 |
-
act_layer: Type[nn.Module] = nn.ReLU,
|
| 109 |
-
norm_layer: Type[nn.Module] = nn.BatchNorm3d,
|
| 110 |
-
):
|
| 111 |
-
"""
|
| 112 |
-
Args:
|
| 113 |
-
inplanes: Input channel dimensionality.
|
| 114 |
-
planes: Used to determine output channel dimensionalities.
|
| 115 |
-
stride: Stride used in convolution layers.
|
| 116 |
-
downsample: Optional downsample layer for residual path.
|
| 117 |
-
cardinality: Number of convolution groups.
|
| 118 |
-
base_width: Base width used to determine output channel dimensionality.
|
| 119 |
-
reduce_first: Reduction factor for first convolution output width of residual blocks.
|
| 120 |
-
dilation: Dilation rate for convolution layers.
|
| 121 |
-
first_dilation: Dilation rate for first convolution layer.
|
| 122 |
-
act_layer: Activation layer.
|
| 123 |
-
norm_layer: Normalization layer.
|
| 124 |
-
"""
|
| 125 |
-
super(Bottleneck, self).__init__()
|
| 126 |
-
|
| 127 |
-
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
| 128 |
-
first_planes = width // reduce_first
|
| 129 |
-
outplanes = planes * self.expansion
|
| 130 |
-
first_dilation = first_dilation or dilation
|
| 131 |
-
|
| 132 |
-
self.conv1 = nn.Conv3d(inplanes, first_planes, kernel_size=1, bias=False)
|
| 133 |
-
self.bn1 = norm_layer(first_planes)
|
| 134 |
-
self.act1 = act_layer()
|
| 135 |
-
|
| 136 |
-
self.conv2 = nn.Conv3d(
|
| 137 |
-
first_planes, width, kernel_size=3, stride=stride,
|
| 138 |
-
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
| 139 |
-
self.bn2 = norm_layer(width)
|
| 140 |
-
self.act2 = act_layer()
|
| 141 |
-
|
| 142 |
-
self.conv3 = nn.Conv3d(width, outplanes, kernel_size=1, bias=False)
|
| 143 |
-
self.bn3 = norm_layer(outplanes)
|
| 144 |
-
|
| 145 |
-
self.act3 = act_layer()
|
| 146 |
-
self.downsample = downsample
|
| 147 |
-
self.stride = stride
|
| 148 |
-
self.dilation = dilation
|
| 149 |
-
|
| 150 |
-
def zero_init_last(self):
|
| 151 |
-
if getattr(self.bn3, 'weight', None) is not None:
|
| 152 |
-
nn.init.zeros_(self.bn3.weight)
|
| 153 |
-
|
| 154 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 155 |
-
shortcut = x
|
| 156 |
-
|
| 157 |
-
x = self.conv1(x)
|
| 158 |
-
x = self.bn1(x)
|
| 159 |
-
x = self.act1(x)
|
| 160 |
-
|
| 161 |
-
x = self.conv2(x)
|
| 162 |
-
x = self.bn2(x)
|
| 163 |
-
x = self.act2(x)
|
| 164 |
-
|
| 165 |
-
x = self.conv3(x)
|
| 166 |
-
x = self.bn3(x)
|
| 167 |
-
|
| 168 |
-
if self.downsample is not None:
|
| 169 |
-
shortcut = self.downsample(shortcut)
|
| 170 |
-
x = x + shortcut
|
| 171 |
-
x = self.act3(x)
|
| 172 |
-
|
| 173 |
-
return x
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def downsample_conv(
|
| 177 |
-
in_channels: int,
|
| 178 |
-
out_channels: int,
|
| 179 |
-
kernel_size: int,
|
| 180 |
-
stride: int = 1,
|
| 181 |
-
dilation: int = 1,
|
| 182 |
-
first_dilation: Optional[int] = None,
|
| 183 |
-
norm_layer: Optional[Type[nn.Module]] = None,
|
| 184 |
-
) -> nn.Module:
|
| 185 |
-
norm_layer = norm_layer or nn.BatchNorm3d
|
| 186 |
-
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
|
| 187 |
-
first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
|
| 188 |
-
p = get_padding(kernel_size, stride, first_dilation)
|
| 189 |
-
|
| 190 |
-
return nn.Sequential(*[
|
| 191 |
-
nn.Conv3d(
|
| 192 |
-
in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False),
|
| 193 |
-
norm_layer(out_channels)
|
| 194 |
-
])
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def downsample_avg(
|
| 198 |
-
in_channels: int,
|
| 199 |
-
out_channels: int,
|
| 200 |
-
kernel_size: int,
|
| 201 |
-
stride: int = 1,
|
| 202 |
-
dilation: int = 1,
|
| 203 |
-
first_dilation: Optional[int] = None,
|
| 204 |
-
norm_layer: Optional[Type[nn.Module]] = None,
|
| 205 |
-
) -> nn.Module:
|
| 206 |
-
norm_layer = norm_layer or nn.BatchNorm3d
|
| 207 |
-
avg_stride = stride if dilation == 1 else 1
|
| 208 |
-
if stride == 1 and dilation == 1:
|
| 209 |
-
pool = nn.Identity()
|
| 210 |
-
else:
|
| 211 |
-
pool = nn.AvgPool3d(2, avg_stride, ceil_mode=True, count_include_pad=False)
|
| 212 |
-
|
| 213 |
-
return nn.Sequential(*[
|
| 214 |
-
pool,
|
| 215 |
-
nn.Conv3d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
|
| 216 |
-
norm_layer(out_channels)
|
| 217 |
-
])
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
def make_blocks(
|
| 221 |
-
block_fns: Tuple[Union[BasicBlock, Bottleneck]],
|
| 222 |
-
channels: Tuple[int, ...],
|
| 223 |
-
block_repeats: Tuple[int, ...],
|
| 224 |
-
inplanes: int,
|
| 225 |
-
reduce_first: int = 1,
|
| 226 |
-
output_stride: int = 32,
|
| 227 |
-
down_kernel_size: int = 1,
|
| 228 |
-
avg_down: bool = False,
|
| 229 |
-
**kwargs,
|
| 230 |
-
) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
|
| 231 |
-
stages = []
|
| 232 |
-
feature_info = []
|
| 233 |
-
net_num_blocks = sum(block_repeats)
|
| 234 |
-
net_block_idx = 0
|
| 235 |
-
net_stride = 4
|
| 236 |
-
dilation = prev_dilation = 1
|
| 237 |
-
for stage_idx, (block_fn, planes, num_blocks) in enumerate(zip(block_fns, channels, block_repeats)):
|
| 238 |
-
stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
|
| 239 |
-
stride = 1 if stage_idx == 0 else 2
|
| 240 |
-
if net_stride >= output_stride:
|
| 241 |
-
dilation *= stride
|
| 242 |
-
stride = 1
|
| 243 |
-
else:
|
| 244 |
-
net_stride *= stride
|
| 245 |
-
|
| 246 |
-
downsample = None
|
| 247 |
-
if stride != 1 or inplanes != planes * block_fn.expansion:
|
| 248 |
-
down_kwargs = dict(
|
| 249 |
-
in_channels=inplanes,
|
| 250 |
-
out_channels=planes * block_fn.expansion,
|
| 251 |
-
kernel_size=down_kernel_size,
|
| 252 |
-
stride=stride,
|
| 253 |
-
dilation=dilation,
|
| 254 |
-
first_dilation=prev_dilation,
|
| 255 |
-
norm_layer=kwargs.get('norm_layer'),
|
| 256 |
-
)
|
| 257 |
-
downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
|
| 258 |
-
|
| 259 |
-
block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, **kwargs)
|
| 260 |
-
blocks = []
|
| 261 |
-
for block_idx in range(num_blocks):
|
| 262 |
-
downsample = downsample if block_idx == 0 else None
|
| 263 |
-
stride = stride if block_idx == 0 else 1
|
| 264 |
-
blocks.append(block_fn(
|
| 265 |
-
inplanes,
|
| 266 |
-
planes,
|
| 267 |
-
stride,
|
| 268 |
-
downsample,
|
| 269 |
-
first_dilation=prev_dilation,
|
| 270 |
-
**block_kwargs,
|
| 271 |
-
))
|
| 272 |
-
prev_dilation = dilation
|
| 273 |
-
inplanes = planes * block_fn.expansion
|
| 274 |
-
net_block_idx += 1
|
| 275 |
-
|
| 276 |
-
stages.append((stage_name, nn.Sequential(*blocks)))
|
| 277 |
-
feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
|
| 278 |
-
|
| 279 |
-
return stages, feature_info
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
def feature_take_indices(
|
| 283 |
-
num_features: int,
|
| 284 |
-
indices: Optional[Union[int, List[int]]] = None,
|
| 285 |
-
as_set: bool = False,
|
| 286 |
-
) -> Tuple[List[int], int]:
|
| 287 |
-
""" Determine the absolute feature indices to 'take' from.
|
| 288 |
-
|
| 289 |
-
Note: This function can be called in forward() so must be torchscript compatible,
|
| 290 |
-
which requires some incomplete typing and workaround hacks.
|
| 291 |
-
|
| 292 |
-
Args:
|
| 293 |
-
num_features: total number of features to select from
|
| 294 |
-
indices: indices to select,
|
| 295 |
-
None -> select all
|
| 296 |
-
int -> select last n
|
| 297 |
-
list/tuple of int -> return specified (-ve indices specify from end)
|
| 298 |
-
as_set: return as a set
|
| 299 |
-
|
| 300 |
-
Returns:
|
| 301 |
-
List (or set) of absolute (from beginning) indices, Maximum index
|
| 302 |
-
"""
|
| 303 |
-
if indices is None:
|
| 304 |
-
indices = num_features # all features if None
|
| 305 |
-
|
| 306 |
-
if isinstance(indices, int):
|
| 307 |
-
# convert int -> last n indices
|
| 308 |
-
assert 0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})'
|
| 309 |
-
take_indices = [num_features - indices + i for i in range(indices)]
|
| 310 |
-
else:
|
| 311 |
-
take_indices: List[int] = []
|
| 312 |
-
for i in indices:
|
| 313 |
-
idx = num_features + i if i < 0 else i
|
| 314 |
-
assert 0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})'
|
| 315 |
-
take_indices.append(idx)
|
| 316 |
-
|
| 317 |
-
if not torch.jit.is_scripting() and as_set:
|
| 318 |
-
return set(take_indices), max(take_indices)
|
| 319 |
-
|
| 320 |
-
return take_indices, max(take_indices)
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
class ResNet(nn.Module):
|
| 324 |
-
"""ResNet / ResNeXt
|
| 325 |
-
|
| 326 |
-
This class implements all variants of ResNet, ResNeXt that
|
| 327 |
-
* have > 1 stride in the 3x3 conv layer of bottleneck
|
| 328 |
-
* have conv-bn-act ordering
|
| 329 |
-
|
| 330 |
-
This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
|
| 331 |
-
variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
|
| 332 |
-
'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
|
| 333 |
-
|
| 334 |
-
ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
|
| 335 |
-
* normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
|
| 336 |
-
* c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
|
| 337 |
-
* d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
|
| 338 |
-
* e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
|
| 339 |
-
* s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
|
| 340 |
-
* t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
|
| 341 |
-
* tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample
|
| 342 |
-
|
| 343 |
-
ResNeXt
|
| 344 |
-
* normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
|
| 345 |
-
* same c,d, e, s variants as ResNet can be enabled
|
| 346 |
-
"""
|
| 347 |
-
|
| 348 |
-
def __init__(
|
| 349 |
-
self,
|
| 350 |
-
block: Union[BasicBlock, Bottleneck],
|
| 351 |
-
layers: Tuple[int, ...],
|
| 352 |
-
num_classes: int = 1000,
|
| 353 |
-
in_chans: int = 1,
|
| 354 |
-
output_stride: int = 32,
|
| 355 |
-
global_pool: str = 'avg',
|
| 356 |
-
cardinality: int = 1,
|
| 357 |
-
base_width: int = 64,
|
| 358 |
-
stem_width: int = 64,
|
| 359 |
-
stem_type: str = '',
|
| 360 |
-
replace_stem_pool: bool = False,
|
| 361 |
-
block_reduce_first: int = 1,
|
| 362 |
-
down_kernel_size: int = 1,
|
| 363 |
-
avg_down: bool = False,
|
| 364 |
-
channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
|
| 365 |
-
act_layer: Type[nn.Module] = nn.ReLU,
|
| 366 |
-
norm_layer: Type[nn.Module] = nn.BatchNorm3d,
|
| 367 |
-
drop_rate: float = 0.0,
|
| 368 |
-
zero_init_last: bool = True,
|
| 369 |
-
block_args: Optional[Dict[str, Any]] = None,
|
| 370 |
-
):
|
| 371 |
-
"""
|
| 372 |
-
Args:
|
| 373 |
-
block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
|
| 374 |
-
layers (List[int]) : number of layers in each block
|
| 375 |
-
num_classes (int): number of classification classes (default 1000)
|
| 376 |
-
in_chans (int): number of input (color) channels. (default 3)
|
| 377 |
-
output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
|
| 378 |
-
global_pool (str): Global pooling type. One of 'avg', 'max' (default 'avg')
|
| 379 |
-
cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
|
| 380 |
-
base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
|
| 381 |
-
stem_width (int): number of channels in stem convolutions (default 64)
|
| 382 |
-
stem_type (str): The type of stem (default ''):
|
| 383 |
-
* '', default - a single 7x7 conv with a width of stem_width
|
| 384 |
-
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
|
| 385 |
-
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
|
| 386 |
-
replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
|
| 387 |
-
block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
|
| 388 |
-
1 for all archs except senets, where 2 (default 1)
|
| 389 |
-
down_kernel_size (int): kernel size of residual block downsample path,
|
| 390 |
-
1x1 for most, 3x3 for senets (default: 1)
|
| 391 |
-
avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
|
| 392 |
-
act_layer (str, nn.Module): activation layer
|
| 393 |
-
norm_layer (str, nn.Module): normalization layer
|
| 394 |
-
drop_rate (float): Dropout probability before classifier, for training (default 0.)
|
| 395 |
-
zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
|
| 396 |
-
block_args (dict): Extra kwargs to pass through to block module
|
| 397 |
-
"""
|
| 398 |
-
super(ResNet, self).__init__()
|
| 399 |
-
block_args = block_args or dict()
|
| 400 |
-
assert output_stride in (8, 16, 32)
|
| 401 |
-
self.num_classes = num_classes
|
| 402 |
-
self.drop_rate = drop_rate
|
| 403 |
-
|
| 404 |
-
# Stem
|
| 405 |
-
deep_stem = 'deep' in stem_type
|
| 406 |
-
inplanes = stem_width * 2 if deep_stem else 64
|
| 407 |
-
if deep_stem:
|
| 408 |
-
stem_chs = (stem_width, stem_width)
|
| 409 |
-
if 'tiered' in stem_type:
|
| 410 |
-
stem_chs = (3 * (stem_width // 4), stem_width)
|
| 411 |
-
self.conv1 = nn.Sequential(*[
|
| 412 |
-
nn.Conv3d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
|
| 413 |
-
norm_layer(stem_chs[0]),
|
| 414 |
-
act_layer(),
|
| 415 |
-
nn.Conv3d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
|
| 416 |
-
norm_layer(stem_chs[1]),
|
| 417 |
-
act_layer(),
|
| 418 |
-
nn.Conv3d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
|
| 419 |
-
else:
|
| 420 |
-
self.conv1 = nn.Conv3d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
| 421 |
-
self.bn1 = norm_layer(inplanes)
|
| 422 |
-
self.act1 = act_layer()
|
| 423 |
-
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
|
| 424 |
-
|
| 425 |
-
# Stem pooling. The name 'maxpool' remains for weight compatibility.
|
| 426 |
-
if replace_stem_pool:
|
| 427 |
-
self.maxpool = nn.Sequential(*filter(None, [
|
| 428 |
-
nn.Conv3d(inplanes, inplanes, 3, stride=2, padding=1, bias=False),
|
| 429 |
-
norm_layer(inplanes),
|
| 430 |
-
act_layer(),
|
| 431 |
-
]))
|
| 432 |
-
else:
|
| 433 |
-
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
|
| 434 |
-
|
| 435 |
-
# Feature Blocks
|
| 436 |
-
block_fns = to_ntuple(len(channels))(block)
|
| 437 |
-
stage_modules, stage_feature_info = make_blocks(
|
| 438 |
-
block_fns,
|
| 439 |
-
channels,
|
| 440 |
-
layers,
|
| 441 |
-
inplanes,
|
| 442 |
-
cardinality=cardinality,
|
| 443 |
-
base_width=base_width,
|
| 444 |
-
output_stride=output_stride,
|
| 445 |
-
reduce_first=block_reduce_first,
|
| 446 |
-
avg_down=avg_down,
|
| 447 |
-
down_kernel_size=down_kernel_size,
|
| 448 |
-
act_layer=act_layer,
|
| 449 |
-
norm_layer=norm_layer,
|
| 450 |
-
**block_args,
|
| 451 |
-
)
|
| 452 |
-
for stage in stage_modules:
|
| 453 |
-
self.add_module(*stage) # layer1, layer2, etc
|
| 454 |
-
self.feature_info.extend(stage_feature_info)
|
| 455 |
-
|
| 456 |
-
# Head (Pooling and Classifier)
|
| 457 |
-
self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
|
| 458 |
-
if global_pool == 'avg':
|
| 459 |
-
self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
|
| 460 |
-
elif global_pool == 'max':
|
| 461 |
-
self.global_pool = nn.AdaptiveMaxPool3d((1, 1, 1))
|
| 462 |
-
else:
|
| 463 |
-
raise NotImplementedError('Global pooling type not supported: {}'.format(global_pool))
|
| 464 |
-
|
| 465 |
-
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
| 466 |
-
|
| 467 |
-
self.init_weights(zero_init_last=zero_init_last)
|
| 468 |
-
|
| 469 |
-
@torch.jit.ignore
|
| 470 |
-
def init_weights(self, zero_init_last: bool = True):
|
| 471 |
-
for n, m in self.named_modules():
|
| 472 |
-
if isinstance(m, nn.Conv3d):
|
| 473 |
-
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 474 |
-
if zero_init_last:
|
| 475 |
-
for m in self.modules():
|
| 476 |
-
if hasattr(m, 'zero_init_last'):
|
| 477 |
-
m.zero_init_last()
|
| 478 |
-
|
| 479 |
-
@torch.jit.ignore
|
| 480 |
-
def group_matcher(self, coarse: bool = False):
|
| 481 |
-
matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
|
| 482 |
-
return matcher
|
| 483 |
-
|
| 484 |
-
@torch.jit.ignore
|
| 485 |
-
def get_classifier(self, name_only: bool = False):
|
| 486 |
-
return 'fc' if name_only else self.fc
|
| 487 |
-
|
| 488 |
-
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
| 489 |
-
self.num_classes = num_classes
|
| 490 |
-
if global_pool == 'avg':
|
| 491 |
-
self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
|
| 492 |
-
elif global_pool == 'max':
|
| 493 |
-
self.global_pool = nn.AdaptiveMaxPool3d((1, 1, 1))
|
| 494 |
-
else:
|
| 495 |
-
raise NotImplementedError('Global pooling type not supported: {}'.format(global_pool))
|
| 496 |
-
|
| 497 |
-
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
| 498 |
-
|
| 499 |
-
def forward_intermediates(
|
| 500 |
-
self,
|
| 501 |
-
x: torch.Tensor,
|
| 502 |
-
indices: Optional[Union[int, List[int]]] = None,
|
| 503 |
-
norm: bool = False,
|
| 504 |
-
stop_early: bool = False,
|
| 505 |
-
output_fmt: str = 'NCHWD',
|
| 506 |
-
intermediates_only: bool = False,
|
| 507 |
-
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
| 508 |
-
""" Forward features that returns intermediates.
|
| 509 |
-
|
| 510 |
-
Args:
|
| 511 |
-
x: Input image tensor
|
| 512 |
-
indices: Take last n blocks if int, all if None, select matching indices if sequence
|
| 513 |
-
norm: Apply norm layer to compatible intermediates
|
| 514 |
-
stop_early: Stop iterating over blocks when last desired intermediate hit
|
| 515 |
-
output_fmt: Shape of intermediate feature outputs
|
| 516 |
-
intermediates_only: Only return intermediate features
|
| 517 |
-
Returns:
|
| 518 |
-
|
| 519 |
-
"""
|
| 520 |
-
assert output_fmt in ('NCHWD',), 'Output shape must be NCHWD.'
|
| 521 |
-
intermediates = []
|
| 522 |
-
take_indices, max_index = feature_take_indices(5, indices)
|
| 523 |
-
|
| 524 |
-
# forward pass
|
| 525 |
-
feat_idx = 0
|
| 526 |
-
x = self.conv1(x)
|
| 527 |
-
x = self.bn1(x)
|
| 528 |
-
x = self.act1(x)
|
| 529 |
-
if feat_idx in take_indices:
|
| 530 |
-
intermediates.append(x)
|
| 531 |
-
x = self.maxpool(x)
|
| 532 |
-
|
| 533 |
-
layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
|
| 534 |
-
if stop_early:
|
| 535 |
-
layer_names = layer_names[:max_index]
|
| 536 |
-
for n in layer_names:
|
| 537 |
-
feat_idx += 1
|
| 538 |
-
x = getattr(self, n)(x) # won't work with torchscript, but keeps code reasonable, FML
|
| 539 |
-
if feat_idx in take_indices:
|
| 540 |
-
intermediates.append(x)
|
| 541 |
-
|
| 542 |
-
if intermediates_only:
|
| 543 |
-
return intermediates
|
| 544 |
-
|
| 545 |
-
return x, intermediates
|
| 546 |
-
|
| 547 |
-
def prune_intermediate_layers(
|
| 548 |
-
self,
|
| 549 |
-
indices: Union[int, List[int]] = 1,
|
| 550 |
-
prune_norm: bool = False,
|
| 551 |
-
prune_head: bool = True,
|
| 552 |
-
):
|
| 553 |
-
""" Prune layers not required for specified intermediates.
|
| 554 |
-
"""
|
| 555 |
-
take_indices, max_index = feature_take_indices(5, indices)
|
| 556 |
-
layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
|
| 557 |
-
layer_names = layer_names[max_index:]
|
| 558 |
-
for n in layer_names:
|
| 559 |
-
setattr(self, n, nn.Identity())
|
| 560 |
-
if prune_head:
|
| 561 |
-
self.reset_classifier(0, '')
|
| 562 |
-
return take_indices
|
| 563 |
-
|
| 564 |
-
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 565 |
-
x = self.conv1(x)
|
| 566 |
-
x = self.bn1(x)
|
| 567 |
-
x = self.act1(x)
|
| 568 |
-
x = self.maxpool(x)
|
| 569 |
-
|
| 570 |
-
x = self.layer1(x)
|
| 571 |
-
x = self.layer2(x)
|
| 572 |
-
x = self.layer3(x)
|
| 573 |
-
x = self.layer4(x)
|
| 574 |
-
return x
|
| 575 |
-
|
| 576 |
-
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
| 577 |
-
x = self.global_pool(x)
|
| 578 |
-
x = x.flatten(1)
|
| 579 |
-
if self.drop_rate:
|
| 580 |
-
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
| 581 |
-
return x if pre_logits else self.fc(x)
|
| 582 |
-
|
| 583 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 584 |
-
x = self.forward_features(x)
|
| 585 |
-
x = self.forward_head(x)
|
| 586 |
-
return x
|
| 587 |
-
|
| 588 |
-
@classmethod
|
| 589 |
-
def from_pretrained(
|
| 590 |
-
cls,
|
| 591 |
-
checkpoint_path_or_url: Union[str, os.PathLike],
|
| 592 |
-
verbose: bool = True,
|
| 593 |
-
**kwargs
|
| 594 |
-
) -> 'ResNet':
|
| 595 |
-
"""Load pretrained model weights from a local path or a URL."""
|
| 596 |
-
model = cls(**kwargs)
|
| 597 |
-
|
| 598 |
-
def _is_url(path: str) -> bool:
|
| 599 |
-
try:
|
| 600 |
-
parsed = urlparse(str(path))
|
| 601 |
-
return parsed.scheme in ('http', 'https')
|
| 602 |
-
except Exception:
|
| 603 |
-
return False
|
| 604 |
-
|
| 605 |
-
if _is_url(checkpoint_path_or_url):
|
| 606 |
-
if verbose:
|
| 607 |
-
print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
|
| 608 |
-
state_dict = torch.hub.load_state_dict_from_url(
|
| 609 |
-
checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
|
| 610 |
-
else:
|
| 611 |
-
local_path = os.fspath(checkpoint_path_or_url)
|
| 612 |
-
if not os.path.exists(local_path):
|
| 613 |
-
raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
|
| 614 |
-
if verbose:
|
| 615 |
-
print(f"Loading checkpoint from local path: {local_path}")
|
| 616 |
-
state_dict = torch.load(local_path, map_location='cpu', weights_only=False)
|
| 617 |
-
|
| 618 |
-
msg = model.load_state_dict(state_dict, strict=False)
|
| 619 |
-
if verbose:
|
| 620 |
-
print(f"Loaded pretrained weights with msg: {msg}")
|
| 621 |
-
return model
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
def resnet18(
|
| 626 |
-
checkpoint_path_or_url: Optional[str] = None,
|
| 627 |
-
**kwargs
|
| 628 |
-
) -> ResNet:
|
| 629 |
-
"""ResNet-18 model with 3D operations.
|
| 630 |
-
"""
|
| 631 |
-
kwargs = dict(
|
| 632 |
-
block=BasicBlock,
|
| 633 |
-
layers=[2, 2, 2, 2],
|
| 634 |
-
cardinality=1,
|
| 635 |
-
**kwargs,
|
| 636 |
-
)
|
| 637 |
-
if checkpoint_path_or_url:
|
| 638 |
-
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 639 |
-
return ResNet(**kwargs)
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
def resnet34(
|
| 643 |
-
checkpoint_path_or_url: Optional[str] = None,
|
| 644 |
-
**kwargs
|
| 645 |
-
) -> ResNet:
|
| 646 |
-
"""ResNet-34 model with 3D operations.
|
| 647 |
-
"""
|
| 648 |
-
kwargs = dict(
|
| 649 |
-
block=BasicBlock,
|
| 650 |
-
layers=[3, 4, 6, 3],
|
| 651 |
-
cardinality=1,
|
| 652 |
-
**kwargs,
|
| 653 |
-
)
|
| 654 |
-
if checkpoint_path_or_url:
|
| 655 |
-
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 656 |
-
return ResNet(**kwargs)
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
def resnet50(
|
| 660 |
-
checkpoint_path_or_url: Optional[str] = None,
|
| 661 |
-
**kwargs
|
| 662 |
-
) -> ResNet:
|
| 663 |
-
"""ResNet-50 model with 3D operations.
|
| 664 |
-
"""
|
| 665 |
-
kwargs = dict(
|
| 666 |
-
block=Bottleneck,
|
| 667 |
-
layers=[3, 4, 6, 3],
|
| 668 |
-
cardinality=1,
|
| 669 |
-
**kwargs,
|
| 670 |
-
)
|
| 671 |
-
if checkpoint_path_or_url:
|
| 672 |
-
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 673 |
-
return ResNet(**kwargs)
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
def resnet101(
|
| 677 |
-
checkpoint_path_or_url: Optional[str] = None,
|
| 678 |
-
**kwargs
|
| 679 |
-
) -> ResNet:
|
| 680 |
-
"""ResNet-101 model with 3D operations.
|
| 681 |
-
"""
|
| 682 |
-
kwargs = dict(
|
| 683 |
-
block=Bottleneck,
|
| 684 |
-
layers=[3, 4, 23, 3],
|
| 685 |
-
cardinality=1,
|
| 686 |
-
**kwargs,
|
| 687 |
-
)
|
| 688 |
-
if checkpoint_path_or_url:
|
| 689 |
-
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 690 |
-
return ResNet(**kwargs)
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
def resnext50(
|
| 694 |
-
checkpoint_path_or_url: Optional[str] = None,
|
| 695 |
-
**kwargs
|
| 696 |
-
) -> ResNet:
|
| 697 |
-
"""ResNeXt-50 model with 3D operations.
|
| 698 |
-
"""
|
| 699 |
-
kwargs = dict(
|
| 700 |
-
block=Bottleneck,
|
| 701 |
-
layers=[3, 4, 6, 3],
|
| 702 |
-
cardinality=32,
|
| 703 |
-
base_width=4,
|
| 704 |
-
**kwargs,
|
| 705 |
-
)
|
| 706 |
-
if checkpoint_path_or_url:
|
| 707 |
-
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 708 |
-
return ResNet(**kwargs)
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
def resnext101(
|
| 712 |
-
checkpoint_path_or_url: Optional[str] = None,
|
| 713 |
-
**kwargs
|
| 714 |
-
) -> ResNet:
|
| 715 |
-
"""ResNeXt-101 model with 3D operations.
|
| 716 |
-
"""
|
| 717 |
-
kwargs = dict(
|
| 718 |
-
block=Bottleneck,
|
| 719 |
-
layers=[3, 4, 23, 3],
|
| 720 |
-
cardinality=32,
|
| 721 |
-
base_width=8,
|
| 722 |
-
**kwargs,
|
| 723 |
-
)
|
| 724 |
-
if checkpoint_path_or_url:
|
| 725 |
-
return ResNet.from_pretrained(checkpoint_path_or_url, **kwargs)
|
| 726 |
-
return ResNet(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/models/seomt.py
DELETED
|
@@ -1,394 +0,0 @@
|
|
| 1 |
-
# Adapted from https://github.com/tue-mps/eomt/
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
from typing import Optional, Tuple, Union
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
from spectre.models import VisionTransformer
|
| 11 |
-
from spectre.models.layers import LayerNorm3d
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class ScaleBlock(nn.Module):
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
embed_dim: int,
|
| 18 |
-
scale_factors: Union[int, Tuple[int, int, int]] = (2, 2, 2),
|
| 19 |
-
conv1_layer: nn.Module = nn.ConvTranspose3d,
|
| 20 |
-
):
|
| 21 |
-
super().__init__()
|
| 22 |
-
|
| 23 |
-
self.conv1 = conv1_layer(
|
| 24 |
-
embed_dim,
|
| 25 |
-
embed_dim,
|
| 26 |
-
kernel_size=scale_factors,
|
| 27 |
-
stride=scale_factors,
|
| 28 |
-
)
|
| 29 |
-
self.act = nn.GELU()
|
| 30 |
-
self.conv2 = nn.Conv3d(
|
| 31 |
-
embed_dim,
|
| 32 |
-
embed_dim,
|
| 33 |
-
kernel_size=3,
|
| 34 |
-
padding=1,
|
| 35 |
-
groups=embed_dim,
|
| 36 |
-
bias=False,
|
| 37 |
-
)
|
| 38 |
-
self.norm = LayerNorm3d(embed_dim)
|
| 39 |
-
|
| 40 |
-
def forward(self, x):
|
| 41 |
-
# print(x.shape)
|
| 42 |
-
x = self.conv1(x)
|
| 43 |
-
x = self.act(x)
|
| 44 |
-
x = self.conv2(x)
|
| 45 |
-
x = self.norm(x)
|
| 46 |
-
|
| 47 |
-
return x
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def compute_upscale_stages(patch_size, min_size=4):
|
| 51 |
-
# Compute how many times to upscale per dimension
|
| 52 |
-
num_stages = []
|
| 53 |
-
for size in patch_size:
|
| 54 |
-
stages = max(0, int(math.log2(size)) - int(math.log2(min_size)))
|
| 55 |
-
num_stages.append(stages)
|
| 56 |
-
return num_stages
|
| 57 |
-
|
| 58 |
-
def voxel_shuffle_3d(x: torch.Tensor, r: int = 2) -> torch.Tensor:
|
| 59 |
-
"""
|
| 60 |
-
Rearranges channels of a 5D tensor (N, C*r^3, D, H, W) to
|
| 61 |
-
(N, C, D*r, H*r, W*r).
|
| 62 |
-
"""
|
| 63 |
-
n, c, d, h, w = x.size()
|
| 64 |
-
assert c % (r ** 3) == 0, f"Channels {c} not divisible by r^3={r**3}"
|
| 65 |
-
out_c = c // (r ** 3)
|
| 66 |
-
x = x.view(n, out_c, r, r, r, d, h, w) # (N, C, r, r, r, D, H, W)
|
| 67 |
-
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() # (N, C, D, r, H, r, W, r)
|
| 68 |
-
x = x.view(n, out_c, d * r, h * r, w * r) # (N, C, D*r, H*r, W*r)
|
| 69 |
-
return x
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class MLPUpBlock3D(nn.Module):
|
| 73 |
-
"""
|
| 74 |
-
2x upsampling via:
|
| 75 |
-
- 1x1x1 Conv (per-voxel MLP) expanding channels by 2^3
|
| 76 |
-
- 3D voxel shuffle to double D/H/W
|
| 77 |
-
- optional norm + activation
|
| 78 |
-
"""
|
| 79 |
-
def __init__(self, channels: int, norm=nn.Identity, activation=nn.GELU):
|
| 80 |
-
super().__init__()
|
| 81 |
-
self.proj = nn.Conv3d(channels, channels * 8, kernel_size=1, bias=True)
|
| 82 |
-
self.norm = norm(channels) if norm is not None else nn.Identity()
|
| 83 |
-
self.act = activation() if activation is not None else nn.Identity()
|
| 84 |
-
|
| 85 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 86 |
-
x = self.proj(x) # (N, C*8, D, H, W)
|
| 87 |
-
x = voxel_shuffle_3d(x, 2) # (N, C, 2D, 2H, 2W)
|
| 88 |
-
x = self.norm(x)
|
| 89 |
-
x = self.act(x)
|
| 90 |
-
return x
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
class SmolMLPDecoder3D(nn.Module):
|
| 94 |
-
"""
|
| 95 |
-
Simple decoder with two MLP upsampling layers (2x each) for total 4x upsampling.
|
| 96 |
-
"""
|
| 97 |
-
def __init__(
|
| 98 |
-
self,
|
| 99 |
-
in_channels: int,
|
| 100 |
-
out_channels: int,
|
| 101 |
-
norm=LayerNorm3d,
|
| 102 |
-
activation=nn.GELU,
|
| 103 |
-
):
|
| 104 |
-
super().__init__()
|
| 105 |
-
self.up1 = MLPUpBlock3D(in_channels, norm=norm, activation=activation)
|
| 106 |
-
self.up2 = MLPUpBlock3D(in_channels, norm=norm, activation=activation)
|
| 107 |
-
self.head = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True)
|
| 108 |
-
|
| 109 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 110 |
-
x = self.up1(x) # 2x
|
| 111 |
-
x = self.up2(x) # another 2x => 4x total
|
| 112 |
-
x = self.head(x) # map to desired output channels
|
| 113 |
-
return x
|
| 114 |
-
|
| 115 |
-
class SpatialMLPDecoder3D(nn.Module):
|
| 116 |
-
"""
|
| 117 |
-
Per-voxel MLP that expands logits from (B,Q,H,W,D) to (B,Q,H*s,W*s,D*s)
|
| 118 |
-
by predicting a learned s^3 block per voxel, then rearranging spatially.
|
| 119 |
-
"""
|
| 120 |
-
def __init__(self, num_classes: int, upscale_factor: int = 4, hidden_mul: int = 4):
|
| 121 |
-
super().__init__()
|
| 122 |
-
self.num_classes = num_classes
|
| 123 |
-
self.s = upscale_factor
|
| 124 |
-
hidden = hidden_mul * num_classes
|
| 125 |
-
self.mlp = nn.Sequential(
|
| 126 |
-
nn.Linear(num_classes, hidden),
|
| 127 |
-
nn.GELU(),
|
| 128 |
-
nn.Linear(hidden, num_classes * (upscale_factor ** 3)),
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 132 |
-
# x: (B, Q, H, W, D)
|
| 133 |
-
B, Q, H, W, D = x.shape
|
| 134 |
-
s = self.s
|
| 135 |
-
x = x.permute(0, 2, 3, 4, 1).contiguous().view(B * H * W * D, Q) # (BHW D, Q)
|
| 136 |
-
x = self.mlp(x) # (BHW D, Q*s^3)
|
| 137 |
-
x = x.view(B, H, W, D, Q, s, s, s) # (B,H,W,D,Q,s,s,s)
|
| 138 |
-
x = x.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous() # (B,Q,H,s,W,s,D,s)
|
| 139 |
-
x = x.view(B, Q, H * s, W * s, D * s)
|
| 140 |
-
return x
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
class SimpleConvDecoder3D(nn.Module):
|
| 144 |
-
"""
|
| 145 |
-
Depthwise ConvTranspose3d upsampler:
|
| 146 |
-
- Assumes input & output channels == num_classes
|
| 147 |
-
- Uses groups=num_classes for class-wise deconvolution
|
| 148 |
-
"""
|
| 149 |
-
def __init__(self, num_classes: int, upscale_factor: int = 4):
|
| 150 |
-
super().__init__()
|
| 151 |
-
s = upscale_factor
|
| 152 |
-
self.deconv = nn.ConvTranspose3d(
|
| 153 |
-
in_channels=num_classes,
|
| 154 |
-
out_channels=num_classes,
|
| 155 |
-
kernel_size=s,
|
| 156 |
-
stride=s,
|
| 157 |
-
padding=0,
|
| 158 |
-
output_padding=0,
|
| 159 |
-
groups=num_classes,
|
| 160 |
-
bias=True,
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 164 |
-
# x: (B, Q, H, W, D)
|
| 165 |
-
return self.deconv(x)
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
class SEoMT(nn.Module):
|
| 169 |
-
def __init__(
|
| 170 |
-
self,
|
| 171 |
-
backbone: VisionTransformer,
|
| 172 |
-
num_classes: int,
|
| 173 |
-
# num_q: int,
|
| 174 |
-
num_blocks=4,
|
| 175 |
-
masked_attn_enabled=True,
|
| 176 |
-
return_only_final_layer=False,
|
| 177 |
-
upscale_output=True,
|
| 178 |
-
for_nnunet=False,
|
| 179 |
-
decoder=False,
|
| 180 |
-
):
|
| 181 |
-
super().__init__()
|
| 182 |
-
self.backbone = backbone
|
| 183 |
-
self.num_q = num_classes
|
| 184 |
-
self.num_blocks = num_blocks
|
| 185 |
-
self.masked_attn_enabled = masked_attn_enabled
|
| 186 |
-
self.return_only_final_layer = return_only_final_layer
|
| 187 |
-
self.upscale_output = upscale_output
|
| 188 |
-
self.register_buffer("attn_mask_probs", torch.ones(num_blocks))
|
| 189 |
-
self.for_nnunet = for_nnunet
|
| 190 |
-
self.q = nn.Embedding(num_classes, self.backbone.embed_dim)
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
self.mask_head = nn.Sequential(
|
| 194 |
-
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 195 |
-
nn.GELU(),
|
| 196 |
-
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 197 |
-
nn.GELU(),
|
| 198 |
-
nn.Linear(self.backbone.embed_dim, self.backbone.embed_dim),
|
| 199 |
-
)
|
| 200 |
-
|
| 201 |
-
patch_size = self.backbone.patch_embed.patch_size
|
| 202 |
-
num_upscale_stages = compute_upscale_stages(patch_size, min_size=4)
|
| 203 |
-
|
| 204 |
-
# Build per-stage scale factors list
|
| 205 |
-
max_stages = max(num_upscale_stages)
|
| 206 |
-
upscale_blocks = []
|
| 207 |
-
for stage_idx in range(max_stages):
|
| 208 |
-
# for each dimension, upscale by 2 only if this dimension still has
|
| 209 |
-
# remaining upscales at this stage
|
| 210 |
-
scale_factors = tuple(
|
| 211 |
-
2 if stage_idx < num_upscale_stages[dim] else 1
|
| 212 |
-
for dim in range(len(patch_size))
|
| 213 |
-
)
|
| 214 |
-
upscale_blocks.append(ScaleBlock(self.backbone.embed_dim,
|
| 215 |
-
scale_factors=scale_factors))
|
| 216 |
-
|
| 217 |
-
self.upscale = nn.Sequential(*upscale_blocks)
|
| 218 |
-
|
| 219 |
-
def _predict(self, x: torch.Tensor, stage: int = None):
|
| 220 |
-
q = x[:, : self.num_q, :]
|
| 221 |
-
# print(stage)
|
| 222 |
-
# class_logits = self.class_head(q)
|
| 223 |
-
x = x[:, self.num_q + self.backbone.num_prefix_tokens :, :]
|
| 224 |
-
x = x.transpose(1, 2).reshape(
|
| 225 |
-
x.shape[0], -1, *self.backbone.patch_embed.grid_size
|
| 226 |
-
)
|
| 227 |
-
mask_logits = torch.einsum(
|
| 228 |
-
"bqc, bchwd -> bqhwd", self.mask_head(q), self.upscale(x)
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
return mask_logits
|
| 233 |
-
|
| 234 |
-
@torch.compiler.disable
|
| 235 |
-
def _disable_attn_mask(self, attn_mask, prob):
|
| 236 |
-
if prob < 1:
|
| 237 |
-
random_queries = (
|
| 238 |
-
torch.rand(attn_mask.shape[0], self.num_q, device=attn_mask.device)
|
| 239 |
-
> prob
|
| 240 |
-
)
|
| 241 |
-
attn_mask[
|
| 242 |
-
:, : self.num_q, self.num_q + self.backbone.num_prefix_tokens :
|
| 243 |
-
][random_queries] = True
|
| 244 |
-
|
| 245 |
-
return attn_mask
|
| 246 |
-
|
| 247 |
-
def _attn(self, module: 'Attention', x: torch.Tensor, mask: Optional[torch.Tensor], rope = None):
|
| 248 |
-
B, N, C = x.shape
|
| 249 |
-
|
| 250 |
-
q = module.q(x).reshape(B, N, module.num_heads, module.head_dim).permute(0, 2, 1, 3)
|
| 251 |
-
kv = module.kv(x).reshape(B, N, 2, module.num_heads, module.head_dim)
|
| 252 |
-
k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
|
| 253 |
-
q, k = module.q_norm(q), module.k_norm(k)
|
| 254 |
-
|
| 255 |
-
if mask is not None:
|
| 256 |
-
mask = mask[:, None, ...].expand(-1, module.num_heads, -1, -1)
|
| 257 |
-
|
| 258 |
-
dropout_p = module.attn_drop.p if self.training else 0.0
|
| 259 |
-
|
| 260 |
-
if rope is not None:
|
| 261 |
-
if isinstance(rope, list):
|
| 262 |
-
rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
|
| 263 |
-
q, k = module.apply_rotary_pos_emb(q, k, rope)
|
| 264 |
-
|
| 265 |
-
if module.fused_attn:
|
| 266 |
-
x = F.scaled_dot_product_attention(q, k, v, mask, dropout_p)
|
| 267 |
-
else:
|
| 268 |
-
attn = (q @ k.transpose(-2, -1)) * module.scale
|
| 269 |
-
if mask is not None:
|
| 270 |
-
attn = attn.masked_fill(~mask, float("-inf"))
|
| 271 |
-
attn = F.softmax(attn, dim=-1)
|
| 272 |
-
attn = module.attn_drop(attn)
|
| 273 |
-
x = attn @ v
|
| 274 |
-
|
| 275 |
-
x = module.proj_drop(module.proj(x.transpose(1, 2).reshape(B, N, C)))
|
| 276 |
-
|
| 277 |
-
return x
|
| 278 |
-
|
| 279 |
-
def forward(self, x: torch.Tensor):
|
| 280 |
-
|
| 281 |
-
if self.for_nnunet: # swap data order, will be incoming at czyx - cxyz
|
| 282 |
-
x = x.permute(0, 1, 3, 4, 2).contiguous()
|
| 283 |
-
|
| 284 |
-
self.backbone.patch_embed.set_input_size(x.shape[2:])
|
| 285 |
-
x = self.backbone.patch_embed(x)
|
| 286 |
-
x, rope = self.backbone._pos_embed(x)
|
| 287 |
-
x = self.backbone.patch_drop(x)
|
| 288 |
-
x = self.backbone.norm_pre(x)
|
| 289 |
-
attn_mask = None
|
| 290 |
-
mask_logits_per_layer = []
|
| 291 |
-
|
| 292 |
-
for i, block in enumerate(self.backbone.blocks):
|
| 293 |
-
if i == len(self.backbone.blocks) - self.num_blocks:
|
| 294 |
-
x = torch.cat(
|
| 295 |
-
(self.q.weight[None, :, :].expand(x.shape[0], -1, -1), x), dim=1
|
| 296 |
-
)
|
| 297 |
-
|
| 298 |
-
if (
|
| 299 |
-
self.masked_attn_enabled
|
| 300 |
-
and i >= len(self.backbone.blocks) - self.num_blocks
|
| 301 |
-
):
|
| 302 |
-
mask_logits = self._predict(self.backbone.norm(x))
|
| 303 |
-
|
| 304 |
-
if self.for_nnunet:
|
| 305 |
-
# swap back to czyx
|
| 306 |
-
|
| 307 |
-
if self.upscale_output:
|
| 308 |
-
# Upscale to original input size / stage
|
| 309 |
-
|
| 310 |
-
stage = len(self.backbone.blocks) -i
|
| 311 |
-
if stage is not None:
|
| 312 |
-
input_size = tuple(
|
| 313 |
-
int(self.backbone.patch_embed.patch_size[dim] * self.backbone.patch_embed.grid_size[dim] / 2**(stage))
|
| 314 |
-
for dim in range(len(self.backbone.patch_embed.patch_size))
|
| 315 |
-
)
|
| 316 |
-
|
| 317 |
-
mask_logits_per_layer.append(F.interpolate(mask_logits, input_size, mode="trilinear").permute(0, 1, 4, 2, 3).contiguous())
|
| 318 |
-
else:
|
| 319 |
-
mask_logits_per_layer.append(mask_logits.permute(0, 1, 4, 2, 3).contiguous())
|
| 320 |
-
else:
|
| 321 |
-
mask_logits_per_layer.append(mask_logits)
|
| 322 |
-
|
| 323 |
-
attn_mask = torch.ones(
|
| 324 |
-
x.shape[0],
|
| 325 |
-
x.shape[1],
|
| 326 |
-
x.shape[1],
|
| 327 |
-
dtype=torch.bool,
|
| 328 |
-
device=x.device,
|
| 329 |
-
)
|
| 330 |
-
interpolated = F.interpolate(
|
| 331 |
-
mask_logits,
|
| 332 |
-
self.backbone.patch_embed.grid_size,
|
| 333 |
-
mode="trilinear",
|
| 334 |
-
)
|
| 335 |
-
interpolated = interpolated.view(
|
| 336 |
-
interpolated.size(0), interpolated.size(1), -1
|
| 337 |
-
)
|
| 338 |
-
attn_mask[
|
| 339 |
-
:,
|
| 340 |
-
: self.num_q,
|
| 341 |
-
self.num_q + self.backbone.num_prefix_tokens :,
|
| 342 |
-
] = (
|
| 343 |
-
interpolated > 0
|
| 344 |
-
)
|
| 345 |
-
attn_mask = self._disable_attn_mask(
|
| 346 |
-
attn_mask,
|
| 347 |
-
self.attn_mask_probs[
|
| 348 |
-
i - len(self.backbone.blocks) + self.num_blocks
|
| 349 |
-
],
|
| 350 |
-
)
|
| 351 |
-
x = x + block.drop_path1(
|
| 352 |
-
block.ls1(self._attn(block.attn, block.norm1(x), attn_mask, rope=rope))
|
| 353 |
-
)
|
| 354 |
-
x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))
|
| 355 |
-
|
| 356 |
-
mask_logits = self._predict(self.backbone.norm(x))
|
| 357 |
-
if self.for_nnunet:
|
| 358 |
-
input_size = tuple(
|
| 359 |
-
int(self.backbone.patch_embed.patch_size[dim] * self.backbone.patch_embed.grid_size[dim] / 2**0)
|
| 360 |
-
for dim in range(len(self.backbone.patch_embed.patch_size))
|
| 361 |
-
)
|
| 362 |
-
mask_logits_per_layer.append(F.interpolate(mask_logits, input_size, mode="trilinear").permute(0, 1, 4, 2, 3).contiguous())
|
| 363 |
-
else:
|
| 364 |
-
mask_logits_per_layer.append(mask_logits)
|
| 365 |
-
|
| 366 |
-
if self.for_nnunet:
|
| 367 |
-
# return in reversed order for deep supervision
|
| 368 |
-
mask_logits_per_layer = mask_logits_per_layer[::-1]
|
| 369 |
-
return mask_logits_per_layer if not self.return_only_final_layer else mask_logits_per_layer[0]
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
if __name__ == "__main__":
|
| 373 |
-
from spectre.models import vit_large_patch16_128
|
| 374 |
-
|
| 375 |
-
model = SEoMT(
|
| 376 |
-
backbone=vit_large_patch16_128(pos_embed='rope',
|
| 377 |
-
rope_kwargs={
|
| 378 |
-
"base": 1000.0, # works for most 3D models
|
| 379 |
-
},),
|
| 380 |
-
num_classes=4,
|
| 381 |
-
num_blocks=4,
|
| 382 |
-
masked_attn_enabled=True,
|
| 383 |
-
return_only_final_layer=True,
|
| 384 |
-
for_nnunet=True,
|
| 385 |
-
upscale_output=True,
|
| 386 |
-
decoder=False,
|
| 387 |
-
)
|
| 388 |
-
# print number of parameters
|
| 389 |
-
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
| 390 |
-
|
| 391 |
-
x = torch.randn(2, 1, 64, 128, 128)
|
| 392 |
-
out = model(x)
|
| 393 |
-
for o in out:
|
| 394 |
-
print(o.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/models/upsample_anything.py
DELETED
|
@@ -1,319 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
from torch.optim.lr_scheduler import LambdaLR
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def UPA(hr_image, lr_volume, device="cuda", use_amp=True):
|
| 9 |
-
"""
|
| 10 |
-
hr_image: numpy or torch [C,Hh,Wh,Dh]
|
| 11 |
-
lr_volume: torch [1,C,Hl,Wl,Dl]
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
hr = torch.as_tensor(hr_image).unsqueeze(0).float().to(device)
|
| 15 |
-
|
| 16 |
-
_, _, Hh, Wh, Dh = hr.shape
|
| 17 |
-
_, _, Hl, Wl, Dl = lr_volume.shape
|
| 18 |
-
scale = Hh // Hl
|
| 19 |
-
assert Wh // Wl == scale and Dh // Dl == scale, "Inconsistent scale factors"
|
| 20 |
-
|
| 21 |
-
lr_volume = lr_volume.to(device).float()
|
| 22 |
-
lr = F.interpolate(hr, scale_factor=1/scale, mode="trilinear", align_corners=False)
|
| 23 |
-
|
| 24 |
-
model = LearnablePixelwiseAnisoJBU3D(
|
| 25 |
-
Hl, Wl, Dl, scale=scale
|
| 26 |
-
).to(device)
|
| 27 |
-
|
| 28 |
-
model.train()
|
| 29 |
-
opt = torch.optim.Adam(model.parameters(), lr=1e-1)
|
| 30 |
-
max_steps = 350
|
| 31 |
-
gamma = (1e-9 / 1e-1) ** (1.0 / max_steps)
|
| 32 |
-
scheduler = LambdaLR(opt, lr_lambda=lambda step: gamma ** step)
|
| 33 |
-
scaler = torch.amp.GradScaler(device=device, enabled=use_amp)
|
| 34 |
-
|
| 35 |
-
for step in range(max_steps):
|
| 36 |
-
opt.zero_grad(set_to_none=True)
|
| 37 |
-
with torch.amp.autocast(device_type=device, enabled=use_amp):
|
| 38 |
-
pred = model(lr, hr)
|
| 39 |
-
loss = F.l1_loss(pred, hr)
|
| 40 |
-
|
| 41 |
-
scaler.scale(loss).backward()
|
| 42 |
-
scaler.step(opt)
|
| 43 |
-
scaler.update()
|
| 44 |
-
scheduler.step()
|
| 45 |
-
|
| 46 |
-
if step % 50 == 0:
|
| 47 |
-
print(f"step {step}: loss={loss.item():.5f}")
|
| 48 |
-
|
| 49 |
-
model.eval()
|
| 50 |
-
with torch.inference_mode(), \
|
| 51 |
-
torch.amp.autocast(device_type=device, enabled=use_amp, dtype=torch.float16):
|
| 52 |
-
out = model(lr_volume, hr)
|
| 53 |
-
|
| 54 |
-
return out
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
@torch.no_grad()
|
| 58 |
-
def _build_offsets_3d(R_max: int, device):
|
| 59 |
-
offs = torch.arange(-R_max, R_max + 1, device=device)
|
| 60 |
-
dX, dY, dZ = torch.meshgrid(offs, offs, offs, indexing="ij")
|
| 61 |
-
return (
|
| 62 |
-
dX.reshape(-1),
|
| 63 |
-
dY.reshape(-1),
|
| 64 |
-
dZ.reshape(-1),
|
| 65 |
-
) # [K]
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def gather_lr_scalar_3d(map_lr, Ui, Vi, Wi):
|
| 69 |
-
"""
|
| 70 |
-
map_lr: [1,1,Hl,Wl,Dl] or [Hl,Wl,Dl]
|
| 71 |
-
Ui,Vi,Wi: [Bn,Hh,Wh,Dh]
|
| 72 |
-
"""
|
| 73 |
-
Hl, Wl, Dl = map_lr.shape[-3:]
|
| 74 |
-
flat = Hl * Wl * Dl
|
| 75 |
-
idx = (Ui * Wl * Dl + Vi * Dl + Wi).reshape(-1)
|
| 76 |
-
t = map_lr.view(flat)
|
| 77 |
-
vals = t.index_select(0, idx)
|
| 78 |
-
return vals.view(Ui.shape)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def gs_jbu_aniso_noparent_3d(
|
| 82 |
-
feat_lr, # [1,C,Hl,Wl,Dl]
|
| 83 |
-
guide_hr, # [1,G,Hh,Wh,Dh]
|
| 84 |
-
scale,
|
| 85 |
-
sigma_x_map,
|
| 86 |
-
sigma_y_map,
|
| 87 |
-
sigma_z_map,
|
| 88 |
-
sigma_r_map,
|
| 89 |
-
R_max=3,
|
| 90 |
-
alpha_dyn=2.0,
|
| 91 |
-
C_chunk=64,
|
| 92 |
-
Nn_chunk=125,
|
| 93 |
-
):
|
| 94 |
-
_, C, Hl, Wl, Dl = feat_lr.shape
|
| 95 |
-
_, _, Hh, Wh, Dh = guide_hr.shape
|
| 96 |
-
device = feat_lr.device
|
| 97 |
-
dtype_feat = feat_lr.dtype
|
| 98 |
-
|
| 99 |
-
# HR grid
|
| 100 |
-
x = torch.arange(Hh, device=device, dtype=torch.float32)
|
| 101 |
-
y = torch.arange(Wh, device=device, dtype=torch.float32)
|
| 102 |
-
z = torch.arange(Dh, device=device, dtype=torch.float32)
|
| 103 |
-
X, Y, Z = torch.meshgrid(x, y, z, indexing="ij")
|
| 104 |
-
|
| 105 |
-
u = (X + 0.5) / scale - 0.5
|
| 106 |
-
v = (Y + 0.5) / scale - 0.5
|
| 107 |
-
w = (Z + 0.5) / scale - 0.5
|
| 108 |
-
|
| 109 |
-
uc = torch.round(u).clamp(0, Hl - 1).long()
|
| 110 |
-
vc = torch.round(v).clamp(0, Wl - 1).long()
|
| 111 |
-
wc = torch.round(w).clamp(0, Dl - 1).long()
|
| 112 |
-
|
| 113 |
-
# Dynamic radius
|
| 114 |
-
sigma_eff = torch.maximum(
|
| 115 |
-
sigma_x_map,
|
| 116 |
-
torch.maximum(sigma_y_map, sigma_z_map),
|
| 117 |
-
)
|
| 118 |
-
sigma_eff_hr = F.interpolate(
|
| 119 |
-
sigma_eff, (Hh, Wh, Dh), mode="trilinear", align_corners=False
|
| 120 |
-
)
|
| 121 |
-
# sigma_eff_hr = sigma_eff_hr.squeeze(0).squeeze(0)
|
| 122 |
-
R_map = torch.ceil(alpha_dyn * sigma_eff_hr).clamp(1, R_max).long()
|
| 123 |
-
|
| 124 |
-
dX_all, dY_all, dZ_all = _build_offsets_3d(R_max, device)
|
| 125 |
-
|
| 126 |
-
num = torch.zeros(C, Hh, Wh, Dh, device=device, dtype=torch.float32)
|
| 127 |
-
den = torch.zeros(Hh, Wh, Dh, device=device, dtype=torch.float32)
|
| 128 |
-
m = torch.full((Hh, Wh, Dh), -1e9, device=device, dtype=torch.float32)
|
| 129 |
-
|
| 130 |
-
feat_flat = feat_lr[0].permute(1, 2, 3, 0).reshape(-1, C)
|
| 131 |
-
guide_lr = F.interpolate(
|
| 132 |
-
guide_hr, (Hl, Wl, Dl), mode="trilinear", align_corners=False
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
for n0 in range(0, len(dX_all), Nn_chunk):
|
| 136 |
-
dX = dX_all[n0:n0+Nn_chunk][:, None, None, None]
|
| 137 |
-
dY = dY_all[n0:n0+Nn_chunk][:, None, None, None]
|
| 138 |
-
dZ = dZ_all[n0:n0+Nn_chunk][:, None, None, None]
|
| 139 |
-
|
| 140 |
-
Ui = torch.clamp(uc.unsqueeze(0) + dX, 0, Hl - 1)
|
| 141 |
-
Vi = torch.clamp(vc.unsqueeze(0) + dY, 0, Wl - 1)
|
| 142 |
-
Wi = torch.clamp(wc.unsqueeze(0) + dZ, 0, Dl - 1)
|
| 143 |
-
|
| 144 |
-
# mask = (dX**2 + dY**2 + dZ**2 <= R_map[None, ...] ** 2)
|
| 145 |
-
mask = (dX**2 + dY**2 + dZ**2 <= R_map**2).squeeze(0).squeeze(0)
|
| 146 |
-
|
| 147 |
-
cx = (Ui.float() + 0.5) * scale - 0.5
|
| 148 |
-
cy = (Vi.float() + 0.5) * scale - 0.5
|
| 149 |
-
cz = (Wi.float() + 0.5) * scale - 0.5
|
| 150 |
-
|
| 151 |
-
dx = X.unsqueeze(0) - cx
|
| 152 |
-
dy = Y.unsqueeze(0) - cy
|
| 153 |
-
dz = Z.unsqueeze(0) - cz
|
| 154 |
-
|
| 155 |
-
sx = gather_lr_scalar_3d(sigma_x_map, Ui, Vi, Wi).clamp_min(1e-6)
|
| 156 |
-
sy = gather_lr_scalar_3d(sigma_y_map, Ui, Vi, Wi).clamp_min(1e-6)
|
| 157 |
-
sz = gather_lr_scalar_3d(sigma_z_map, Ui, Vi, Wi).clamp_min(1e-6)
|
| 158 |
-
sr = gather_lr_scalar_3d(sigma_r_map, Ui, Vi, Wi).clamp_min(1e-6)
|
| 159 |
-
|
| 160 |
-
log_ws = (
|
| 161 |
-
-(dx**2)/(2*sx**2)
|
| 162 |
-
-(dy**2)/(2*sy**2)
|
| 163 |
-
-(dz**2)/(2*sz**2)
|
| 164 |
-
)
|
| 165 |
-
|
| 166 |
-
diff2 = 0.0
|
| 167 |
-
for g in range(guide_hr.shape[1]):
|
| 168 |
-
g0 = gather_lr_scalar_3d(guide_lr[0, g], Ui, Vi, Wi)
|
| 169 |
-
diff2 += (guide_hr[0, g] - g0) ** 2
|
| 170 |
-
|
| 171 |
-
log_wr = -diff2 / (2 * sr**2 + 1e-8)
|
| 172 |
-
log_w = torch.where(mask, log_ws + log_wr, -1e9)
|
| 173 |
-
|
| 174 |
-
m_chunk = log_w.max(dim=0).values
|
| 175 |
-
m_new = torch.maximum(m, m_chunk)
|
| 176 |
-
|
| 177 |
-
scale_old = torch.exp(m - m_new)
|
| 178 |
-
num *= scale_old
|
| 179 |
-
den *= scale_old
|
| 180 |
-
|
| 181 |
-
w = torch.exp(log_w - m_new)
|
| 182 |
-
den += w.sum(0)
|
| 183 |
-
|
| 184 |
-
idx_flat = (Ui * Wl * Dl + Vi * Dl + Wi).reshape(-1)
|
| 185 |
-
|
| 186 |
-
for c0 in range(0, C, C_chunk):
|
| 187 |
-
c1 = min(c0 + C_chunk, C)
|
| 188 |
-
f = feat_flat.index_select(0, idx_flat)[:, c0:c1]
|
| 189 |
-
f = f.view(w.shape + (c1 - c0,))
|
| 190 |
-
num[c0:c1] += (f * w[..., None]).sum(0).permute(3, 0, 1, 2)
|
| 191 |
-
|
| 192 |
-
m = m_new
|
| 193 |
-
|
| 194 |
-
out = (num / den.clamp_min(1e-8)).unsqueeze(0)
|
| 195 |
-
return out.to(dtype_feat)
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
class LearnablePixelwiseAnisoJBU3D(nn.Module):
|
| 199 |
-
def __init__(
|
| 200 |
-
self,
|
| 201 |
-
Hl,
|
| 202 |
-
Wl,
|
| 203 |
-
Dl,
|
| 204 |
-
scale,
|
| 205 |
-
init_sigma=1.5,
|
| 206 |
-
init_sigma_r=0.1,
|
| 207 |
-
R_max=3,
|
| 208 |
-
alpha_dyn=2.0,
|
| 209 |
-
):
|
| 210 |
-
super().__init__()
|
| 211 |
-
self.scale = scale
|
| 212 |
-
self.R_max = R_max
|
| 213 |
-
self.alpha_dyn = alpha_dyn
|
| 214 |
-
|
| 215 |
-
self.sx_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
|
| 216 |
-
self.sy_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
|
| 217 |
-
self.sz_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma)))
|
| 218 |
-
self.sr_raw = nn.Parameter(torch.full((1,1,Hl,Wl,Dl), math.log(init_sigma_r)))
|
| 219 |
-
|
| 220 |
-
def forward(self, feat_lr, guide_hr):
|
| 221 |
-
return gs_jbu_aniso_noparent_3d(
|
| 222 |
-
feat_lr,
|
| 223 |
-
guide_hr,
|
| 224 |
-
self.scale,
|
| 225 |
-
torch.exp(self.sx_raw),
|
| 226 |
-
torch.exp(self.sy_raw),
|
| 227 |
-
torch.exp(self.sz_raw),
|
| 228 |
-
torch.exp(self.sr_raw),
|
| 229 |
-
R_max=self.R_max,
|
| 230 |
-
alpha_dyn=self.alpha_dyn,
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
if __name__ == "__main__":
|
| 235 |
-
import argparse
|
| 236 |
-
|
| 237 |
-
import numpy as np
|
| 238 |
-
import nibabel as nib
|
| 239 |
-
import monai.transforms as transforms
|
| 240 |
-
|
| 241 |
-
parser = argparse.ArgumentParser()
|
| 242 |
-
parser.add_argument("--image_path", type=str, required=True)
|
| 243 |
-
parser.add_argument("--mask_path", type=str, required=True)
|
| 244 |
-
parser.add_argument("--device", type=str, default="cuda")
|
| 245 |
-
parser.add_argument("--use_amp", action="store_true")
|
| 246 |
-
args = parser.parse_args()
|
| 247 |
-
|
| 248 |
-
transform = transforms.Compose([
|
| 249 |
-
transforms.LoadImaged(keys=("image", "mask")),
|
| 250 |
-
transforms.EnsureChannelFirstd(keys=("image", "mask"), channel_dim="no_channel"),
|
| 251 |
-
transforms.ScaleIntensityRanged(
|
| 252 |
-
keys=("image",),
|
| 253 |
-
a_min=-150,
|
| 254 |
-
a_max=250,
|
| 255 |
-
b_min=0.0,
|
| 256 |
-
b_max=1.0,
|
| 257 |
-
clip=True,
|
| 258 |
-
),
|
| 259 |
-
transforms.Orientationd(keys=("image", "mask"), axcodes="RAS"),
|
| 260 |
-
transforms.RandWeightedCropd(
|
| 261 |
-
keys=("image", "mask"),
|
| 262 |
-
w_key="mask",
|
| 263 |
-
spatial_size=(128, 128, 64),
|
| 264 |
-
num_samples=1,
|
| 265 |
-
),
|
| 266 |
-
transforms.CopyItemsd(keys=("mask"), times=1, names=("mask_low_res")),
|
| 267 |
-
transforms.Resized(keys=("mask_low_res"), spatial_size=(16, 16, 8), mode="nearest", align_corners=False)
|
| 268 |
-
])
|
| 269 |
-
sample = transform({
|
| 270 |
-
"image": args.image_path,
|
| 271 |
-
"mask": args.mask_path,
|
| 272 |
-
})[0]
|
| 273 |
-
|
| 274 |
-
nib.save(
|
| 275 |
-
nib.Nifti1Image(
|
| 276 |
-
(F.interpolate(sample["mask_low_res"].unsqueeze(0), size=(128, 128, 64), mode="nearest").squeeze(0).squeeze(0).cpu().numpy().astype(np.uint8)),
|
| 277 |
-
affine=np.eye(4),
|
| 278 |
-
),
|
| 279 |
-
"mask_low_res_upscaled.nii.gz",
|
| 280 |
-
)
|
| 281 |
-
|
| 282 |
-
sample["mask_low_res"] = F.one_hot(
|
| 283 |
-
sample["mask_low_res"].long().squeeze(0), num_classes=4,
|
| 284 |
-
).permute(3, 0, 1, 2).unsqueeze(0).float()
|
| 285 |
-
|
| 286 |
-
print(sample["mask_low_res"].shape)
|
| 287 |
-
|
| 288 |
-
mask_out = UPA(
|
| 289 |
-
sample["image"],
|
| 290 |
-
sample["mask_low_res"],
|
| 291 |
-
device=args.device,
|
| 292 |
-
use_amp=args.use_amp,
|
| 293 |
-
)
|
| 294 |
-
|
| 295 |
-
mask_out = mask_out.argmax(dim=1, keepdim=True)
|
| 296 |
-
|
| 297 |
-
nib.save(
|
| 298 |
-
nib.Nifti1Image(
|
| 299 |
-
(sample["image"] * 255).squeeze(0).cpu().numpy().astype(np.uint8),
|
| 300 |
-
affine=np.eye(4),
|
| 301 |
-
),
|
| 302 |
-
"image.nii.gz",
|
| 303 |
-
)
|
| 304 |
-
nib.save(
|
| 305 |
-
nib.Nifti1Image(
|
| 306 |
-
sample["mask"].squeeze(0).cpu().numpy().astype(np.uint8),
|
| 307 |
-
affine=np.eye(4),
|
| 308 |
-
),
|
| 309 |
-
"mask.nii.gz",
|
| 310 |
-
)
|
| 311 |
-
torch.save(mask_out.squeeze(0).squeeze(0).cpu(), "upsampled_mask.pt")
|
| 312 |
-
nib.save(
|
| 313 |
-
nib.Nifti1Image(
|
| 314 |
-
mask_out.squeeze(0).squeeze(0).cpu().numpy().astype(np.uint8),
|
| 315 |
-
affine=np.eye(4),
|
| 316 |
-
),
|
| 317 |
-
"upsampled_mask.nii.gz",
|
| 318 |
-
)
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/checkpointing.py
DELETED
|
@@ -1,238 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import random
|
| 3 |
-
import warnings
|
| 4 |
-
from typing import Optional, Any
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch.distributed as dist
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def _get_local_rng_state() -> dict:
|
| 12 |
-
"""Return a picklable dict with local RNG states (cpu & cuda, numpy, random)."""
|
| 13 |
-
state = {
|
| 14 |
-
"torch": torch.get_rng_state().cpu(),
|
| 15 |
-
"numpy": np.random.get_state(),
|
| 16 |
-
"random": random.getstate(),
|
| 17 |
-
}
|
| 18 |
-
|
| 19 |
-
if torch.cuda.is_available():
|
| 20 |
-
# make sure CUDA states are stored on CPU so they are picklable
|
| 21 |
-
cuda_states = [s.cpu() for s in torch.cuda.get_rng_state_all()]
|
| 22 |
-
state["cuda"] = cuda_states
|
| 23 |
-
else:
|
| 24 |
-
state["cuda"] = None
|
| 25 |
-
|
| 26 |
-
return state
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def _set_local_rng_state(state: dict) -> None:
|
| 30 |
-
"""Set local RNG states from the dict produced by _get_local_rng_state()."""
|
| 31 |
-
if state is None:
|
| 32 |
-
return
|
| 33 |
-
|
| 34 |
-
if "torch" in state and state["torch"] is not None:
|
| 35 |
-
torch.set_rng_state(state["torch"])
|
| 36 |
-
if "cuda" in state and state["cuda"] is not None and torch.cuda.is_available():
|
| 37 |
-
try:
|
| 38 |
-
# move back to CUDA tensors for this process and set them
|
| 39 |
-
cuda_states = [s.cuda() for s in state["cuda"]]
|
| 40 |
-
torch.cuda.set_rng_state_all(cuda_states)
|
| 41 |
-
except Exception:
|
| 42 |
-
# fallback: try setting per-device RNG if set_rng_state_all fails
|
| 43 |
-
for i, s in enumerate(state["cuda"]):
|
| 44 |
-
try:
|
| 45 |
-
torch.cuda.set_rng_state(s.cuda(), device=i)
|
| 46 |
-
except Exception:
|
| 47 |
-
# ignore if device mismatch
|
| 48 |
-
pass
|
| 49 |
-
|
| 50 |
-
if "numpy" in state and state["numpy"] is not None:
|
| 51 |
-
np.random.set_state(state["numpy"])
|
| 52 |
-
if "random" in state and state["random"] is not None:
|
| 53 |
-
random.setstate(state["random"])
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def save_state(ckpt_path: str, epoch: Optional[int] = None, **named_objects: Any) -> None:
|
| 57 |
-
"""
|
| 58 |
-
Save a checkpoint that includes:
|
| 59 |
-
- epoch (optional)
|
| 60 |
-
- state_dicts for provided named_objects
|
| 61 |
-
- rng_states: list of per-rank RNG dictionaries (one entry per world rank)
|
| 62 |
-
|
| 63 |
-
If torch.distributed is initialized the RNG states from all ranks are gathered and
|
| 64 |
-
stored in checkpoint["rng_states"] (list indexed by rank). Only rank 0 writes the file.
|
| 65 |
-
In single-process mode the checkpoint contains a single-item rng_states list.
|
| 66 |
-
"""
|
| 67 |
-
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
|
| 68 |
-
|
| 69 |
-
# prepare local RNG state
|
| 70 |
-
local_rng = _get_local_rng_state()
|
| 71 |
-
|
| 72 |
-
# distributed path: gather RNG states from all ranks
|
| 73 |
-
if dist.is_available() and dist.is_initialized():
|
| 74 |
-
rank = dist.get_rank()
|
| 75 |
-
world_size = dist.get_world_size()
|
| 76 |
-
all_states = [None] * world_size
|
| 77 |
-
# gather python objects (picklable)
|
| 78 |
-
dist.all_gather_object(all_states, local_rng)
|
| 79 |
-
|
| 80 |
-
# only rank 0 writes the checkpoint file
|
| 81 |
-
if rank == 0:
|
| 82 |
-
checkpoint = {}
|
| 83 |
-
if epoch is not None:
|
| 84 |
-
checkpoint["epoch"] = epoch
|
| 85 |
-
checkpoint["rng_states"] = all_states
|
| 86 |
-
|
| 87 |
-
# save provided objects' state_dicts (rank 0's local state_dicts)
|
| 88 |
-
for name, obj in named_objects.items():
|
| 89 |
-
checkpoint[name] = obj.state_dict()
|
| 90 |
-
|
| 91 |
-
torch.save(checkpoint, ckpt_path)
|
| 92 |
-
|
| 93 |
-
# ensure everyone waits until rank 0 finished writing
|
| 94 |
-
dist.barrier()
|
| 95 |
-
|
| 96 |
-
else:
|
| 97 |
-
# single-process fallback
|
| 98 |
-
checkpoint = {}
|
| 99 |
-
if epoch is not None:
|
| 100 |
-
checkpoint["epoch"] = epoch
|
| 101 |
-
checkpoint["rng_states"] = [local_rng]
|
| 102 |
-
for name, obj in named_objects.items():
|
| 103 |
-
checkpoint[name] = obj.state_dict()
|
| 104 |
-
torch.save(checkpoint, ckpt_path)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def load_state(ckpt_path: str, **named_objects: Any) -> int:
|
| 108 |
-
"""
|
| 109 |
-
Load checkpoint saved by save_state.
|
| 110 |
-
|
| 111 |
-
- Each process loads the same file and restores its own RNG state (checkpoint['rng_states'][rank]).
|
| 112 |
-
- Named objects that exist in the checkpoint will have their state_dict loaded.
|
| 113 |
-
- Returns epoch (int) if present, otherwise 0.
|
| 114 |
-
"""
|
| 115 |
-
if not os.path.isfile(ckpt_path):
|
| 116 |
-
warnings.warn(f"Checkpoint file not found: {ckpt_path}")
|
| 117 |
-
return 0
|
| 118 |
-
|
| 119 |
-
# load on all ranks (shared FS assumed)
|
| 120 |
-
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 121 |
-
epoch = checkpoint.get("epoch", 0)
|
| 122 |
-
|
| 123 |
-
# load state_dicts into provided objects
|
| 124 |
-
for name, obj in named_objects.items():
|
| 125 |
-
if name in checkpoint:
|
| 126 |
-
try:
|
| 127 |
-
obj.load_state_dict(checkpoint[name])
|
| 128 |
-
except Exception as e:
|
| 129 |
-
warnings.warn(f"Failed to load state_dict for '{name}': {e}")
|
| 130 |
-
else:
|
| 131 |
-
warnings.warn(f"No state_dict found for '{name}' in checkpoint.")
|
| 132 |
-
|
| 133 |
-
# restore this rank's RNG state
|
| 134 |
-
rng_states = checkpoint.get("rng_states", None)
|
| 135 |
-
if rng_states is not None:
|
| 136 |
-
if dist.is_available() and dist.is_initialized():
|
| 137 |
-
rank = dist.get_rank()
|
| 138 |
-
if rank < len(rng_states):
|
| 139 |
-
my_state = rng_states[rank]
|
| 140 |
-
else:
|
| 141 |
-
my_state = None
|
| 142 |
-
else:
|
| 143 |
-
# single-process file: first element
|
| 144 |
-
my_state = rng_states[0] if len(rng_states) > 0 else None
|
| 145 |
-
|
| 146 |
-
try:
|
| 147 |
-
_set_local_rng_state(my_state)
|
| 148 |
-
except Exception as e:
|
| 149 |
-
warnings.warn(f"Failed to restore RNG state: {e}")
|
| 150 |
-
|
| 151 |
-
else:
|
| 152 |
-
warnings.warn("No 'rng_states' found in checkpoint; RNGs not restored.")
|
| 153 |
-
|
| 154 |
-
return epoch
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def extract_model_from_checkpoint_dinov2(checkpoint_path: str):
|
| 158 |
-
# Load the checkpoint
|
| 159 |
-
checkpoint = torch.load(
|
| 160 |
-
checkpoint_path,
|
| 161 |
-
weights_only=False,
|
| 162 |
-
map_location="cpu"
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
# Get model state dict
|
| 166 |
-
model_state = checkpoint.get("model", checkpoint)
|
| 167 |
-
|
| 168 |
-
# Create output folder
|
| 169 |
-
output_dir = str(checkpoint_path).replace(".pt", "")
|
| 170 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 171 |
-
|
| 172 |
-
# Quick check: compare the parameters of head_teacher_ibot vs head_teacher_dino
|
| 173 |
-
teacher_dino_keys = [k for k in model_state.keys() if k.startswith("head_teacher_dino.")]
|
| 174 |
-
teacher_ibot_keys = [k for k in model_state.keys() if k.startswith("head_teacher_ibot.")]
|
| 175 |
-
|
| 176 |
-
ibot_separate = True
|
| 177 |
-
if teacher_dino_keys and teacher_ibot_keys:
|
| 178 |
-
if all(torch.equal(model_state[dino_key], model_state[ibot_key]) \
|
| 179 |
-
for dino_key, ibot_key in zip(teacher_dino_keys, teacher_ibot_keys)):
|
| 180 |
-
ibot_separate = False # Same weights → no separate ibot head
|
| 181 |
-
|
| 182 |
-
# Define the components to save
|
| 183 |
-
components = {
|
| 184 |
-
"backbone_teacher.pt": "backbone_teacher.vit",
|
| 185 |
-
"backbone_student.pt": "backbone_student.vit",
|
| 186 |
-
"head_student_dino.pt": "head_student_dino",
|
| 187 |
-
"head_teacher_dino.pt": "head_teacher_dino"
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
# Add ibot heads only if separate
|
| 191 |
-
if ibot_separate:
|
| 192 |
-
components["head_student_ibot.pt"] = "head_student_ibot"
|
| 193 |
-
components["head_teacher_ibot.pt"] = "head_teacher_ibot"
|
| 194 |
-
|
| 195 |
-
# Extract and save each component
|
| 196 |
-
for filename, key in components.items():
|
| 197 |
-
sub_state_dict = {k.replace(f"{key}.", ""): v for k, v in model_state.items() if k.startswith(key)}
|
| 198 |
-
if not sub_state_dict:
|
| 199 |
-
print(f"[WARNING] No parameters found for {key}, skipping...")
|
| 200 |
-
continue
|
| 201 |
-
torch.save(sub_state_dict, os.path.join(output_dir, filename))
|
| 202 |
-
|
| 203 |
-
print(f"Components extracted to: {output_dir}")
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
def extract_model_from_checkpoint_siglip(checkpoint_path: str):
|
| 207 |
-
# Load the checkpoint
|
| 208 |
-
checkpoint = torch.load(
|
| 209 |
-
checkpoint_path,
|
| 210 |
-
weights_only=False,
|
| 211 |
-
map_location="cpu",
|
| 212 |
-
)
|
| 213 |
-
|
| 214 |
-
# Get model state dict
|
| 215 |
-
model_state = checkpoint.get("model", checkpoint)
|
| 216 |
-
|
| 217 |
-
# Create output folder
|
| 218 |
-
output_dir = str(checkpoint_path).replace(".pt", "")
|
| 219 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 220 |
-
|
| 221 |
-
# Define the components to save
|
| 222 |
-
components = {
|
| 223 |
-
"backbone_image.pt": "backbone_image",
|
| 224 |
-
"backbone_text.pt": "backbone_text",
|
| 225 |
-
"feature_comb_image.pt": "feature_comb_image",
|
| 226 |
-
"projection_image.pt": "projection_image",
|
| 227 |
-
"projection_text.pt": "projection_text"
|
| 228 |
-
}
|
| 229 |
-
|
| 230 |
-
# Extract and save each component
|
| 231 |
-
for filename, key in components.items():
|
| 232 |
-
sub_state_dict = {k.replace(f"{key}.", ""): v for k, v in model_state.items() if k.startswith(key)}
|
| 233 |
-
if not sub_state_dict:
|
| 234 |
-
print(f"[WARNING] No parameters found for {key}, skipping...")
|
| 235 |
-
continue
|
| 236 |
-
torch.save(sub_state_dict, os.path.join(output_dir, filename))
|
| 237 |
-
|
| 238 |
-
print(f"Components extracted to: {output_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/collate.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
from typing import List, Callable, Optional
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
MONAI_IMPORT_ERROR = None
|
| 6 |
-
try:
|
| 7 |
-
from monai.data import list_data_collate
|
| 8 |
-
except ImportError as e:
|
| 9 |
-
list_data_collate = lambda x: x # type: ignore
|
| 10 |
-
MONAI_IMPORT_ERROR = e
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def extended_collate_dino(samples_list: List) -> dict:
|
| 14 |
-
"""
|
| 15 |
-
Applies MONAI's list_data_collate first and then extends it with DINOv2 masking logic.
|
| 16 |
-
|
| 17 |
-
Args:
|
| 18 |
-
samples_list: List of samples containing 'global_crops' and 'local_crops'.
|
| 19 |
-
mask_ratio: Tuple defining the range of masking ratios.
|
| 20 |
-
mask_probability: Probability of applying masking.
|
| 21 |
-
dtype: Data type to cast the collated tensors.
|
| 22 |
-
n_tokens: Number of tokens for masking.
|
| 23 |
-
mask_generator: Function to generate masks.
|
| 24 |
-
|
| 25 |
-
Returns:
|
| 26 |
-
A dictionary with collated global/local crops and corresponding masks.
|
| 27 |
-
"""
|
| 28 |
-
if MONAI_IMPORT_ERROR is not None:
|
| 29 |
-
raise ImportError(
|
| 30 |
-
"MONAI is required to use extended_collate_dino but not installed. "
|
| 31 |
-
"Please install MONAI to use this collate function."
|
| 32 |
-
) from MONAI_IMPORT_ERROR
|
| 33 |
-
|
| 34 |
-
# Apply MONAI's list_data_collate
|
| 35 |
-
collated_data = list_data_collate(samples_list)
|
| 36 |
-
|
| 37 |
-
# Extract crops
|
| 38 |
-
global_views = torch.cat(collated_data["image_global_views"], dim=0)
|
| 39 |
-
local_views = torch.cat(collated_data["image_local_views"], dim=0)
|
| 40 |
-
|
| 41 |
-
return {
|
| 42 |
-
"global_views": global_views,
|
| 43 |
-
"local_views": local_views,
|
| 44 |
-
}
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def extended_collate_siglip(
|
| 48 |
-
samples_list: List,
|
| 49 |
-
tokenizer: Optional[Callable] = None,
|
| 50 |
-
tokenizer_padding: bool = True,
|
| 51 |
-
tokenizer_truncation: bool = True,
|
| 52 |
-
tokenizer_max_length: Optional[int] = 1024,
|
| 53 |
-
return_filenames: bool = False
|
| 54 |
-
) -> dict:
|
| 55 |
-
"""
|
| 56 |
-
Applies SigLIP collate and then extends it with tokenization logic.
|
| 57 |
-
|
| 58 |
-
Args:
|
| 59 |
-
samples_list: List of samples containing 'image' and 'report'.
|
| 60 |
-
tokenizer: Tokenizer function to apply on the reports.
|
| 61 |
-
|
| 62 |
-
Returns:
|
| 63 |
-
A dictionary with collated images and tokenized text.
|
| 64 |
-
"""
|
| 65 |
-
if MONAI_IMPORT_ERROR is not None:
|
| 66 |
-
raise ImportError(
|
| 67 |
-
"MONAI is required to use extended_collate_siglip but not installed. "
|
| 68 |
-
"Please install MONAI to use this collate function."
|
| 69 |
-
) from MONAI_IMPORT_ERROR
|
| 70 |
-
|
| 71 |
-
collated_data = list_data_collate(samples_list)
|
| 72 |
-
|
| 73 |
-
if return_filenames:
|
| 74 |
-
if "image" in collated_data.keys():
|
| 75 |
-
if (
|
| 76 |
-
hasattr(samples_list[0]["image"].data, "meta")
|
| 77 |
-
and "filename_or_obj" in samples_list[0]["image"].data.meta
|
| 78 |
-
):
|
| 79 |
-
collated_data["filename"] = [s["image"].data.meta["filename_or_obj"] for s in samples_list]
|
| 80 |
-
|
| 81 |
-
if tokenizer is not None and "report" in collated_data.keys():
|
| 82 |
-
tokenizer_output = tokenizer.batch_encode_plus(
|
| 83 |
-
collated_data["report"],
|
| 84 |
-
add_special_tokens=True,
|
| 85 |
-
padding=tokenizer_padding,
|
| 86 |
-
truncation=tokenizer_truncation,
|
| 87 |
-
max_length=tokenizer_max_length,
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
collated_data["input_ids"] = torch.tensor(tokenizer_output["input_ids"])
|
| 91 |
-
collated_data["attention_mask"] = torch.tensor(tokenizer_output["attention_mask"])
|
| 92 |
-
|
| 93 |
-
return collated_data
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def collate_add_filenames(samples_list: List) -> dict:
|
| 97 |
-
"""
|
| 98 |
-
Applies MONAI's list_data_collate and adds filenames to the collated output.
|
| 99 |
-
|
| 100 |
-
Args:
|
| 101 |
-
samples_list: List of samples containing 'image' with metadata.
|
| 102 |
-
Returns:
|
| 103 |
-
A dictionary with collated images and filenames.
|
| 104 |
-
"""
|
| 105 |
-
if MONAI_IMPORT_ERROR is not None:
|
| 106 |
-
raise ImportError(
|
| 107 |
-
"MONAI is required to use collate_add_filenames but not installed. "
|
| 108 |
-
"Please install MONAI to use this collate function."
|
| 109 |
-
) from MONAI_IMPORT_ERROR
|
| 110 |
-
|
| 111 |
-
collated_data = list_data_collate(samples_list)
|
| 112 |
-
|
| 113 |
-
if "image" in collated_data.keys():
|
| 114 |
-
if (
|
| 115 |
-
hasattr(samples_list[0]["image"].data, "meta")
|
| 116 |
-
and "filename_or_obj" in samples_list[0]["image"].data.meta
|
| 117 |
-
):
|
| 118 |
-
collated_data["filename"] = [s["image"].data.meta["filename_or_obj"] for s in samples_list]
|
| 119 |
-
|
| 120 |
-
return collated_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/config.py
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import math
|
| 3 |
-
|
| 4 |
-
from spectre.utils import _utils, distributed
|
| 5 |
-
|
| 6 |
-
OMEGACONF_IMPORT_ERROR = None
|
| 7 |
-
try:
|
| 8 |
-
from omegaconf import OmegaConf
|
| 9 |
-
except ImportError as e:
|
| 10 |
-
OmegaConf = None # type: ignore
|
| 11 |
-
OMEGACONF_IMPORT_ERROR = e
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def apply_scaling_rules_to_cfg(cfg):
|
| 15 |
-
"""
|
| 16 |
-
Apply learing rate scaling rules to the configuration object.
|
| 17 |
-
"""
|
| 18 |
-
base_lr = cfg.optim.base_lr
|
| 19 |
-
cfg.optim.lr = base_lr
|
| 20 |
-
|
| 21 |
-
# Apply scaling rules
|
| 22 |
-
if cfg.optim.scaling_rule == "constant":
|
| 23 |
-
return cfg
|
| 24 |
-
|
| 25 |
-
try:
|
| 26 |
-
scaling_type, ref_batch_size = cfg.optim.scaling_rule.split("_wrt_")
|
| 27 |
-
ref_batch_size = float(ref_batch_size)
|
| 28 |
-
except ValueError:
|
| 29 |
-
raise NotImplementedError(f"Unknown scaling rule: {cfg.optim.scaling_rule}")
|
| 30 |
-
|
| 31 |
-
scale_factor = cfg.train.batch_size_per_gpu * distributed.get_global_size()
|
| 32 |
-
scale_factor /= ref_batch_size
|
| 33 |
-
scale_factor *= cfg.train.grad_accum_steps
|
| 34 |
-
|
| 35 |
-
if scaling_type == "sqrt":
|
| 36 |
-
cfg.optim.lr *= math.sqrt(scale_factor)
|
| 37 |
-
elif scaling_type == "linear":
|
| 38 |
-
cfg.optim.lr *= scale_factor
|
| 39 |
-
else:
|
| 40 |
-
raise NotImplementedError(f"Unsupported scaling type: {scaling_type}")
|
| 41 |
-
|
| 42 |
-
return cfg
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def write_config(cfg, output_dir, name="config.yaml"):
|
| 46 |
-
if OMEGACONF_IMPORT_ERROR is not None:
|
| 47 |
-
raise ImportError(
|
| 48 |
-
"OmegaConf is required to use write_config but not installed. "
|
| 49 |
-
"Please install OmegaConf to use this function."
|
| 50 |
-
) from OMEGACONF_IMPORT_ERROR
|
| 51 |
-
|
| 52 |
-
saved_cfg_path = os.path.join(output_dir, name)
|
| 53 |
-
with open(saved_cfg_path, "w") as f:
|
| 54 |
-
OmegaConf.save(config=cfg, f=f)
|
| 55 |
-
return saved_cfg_path
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def get_cfg_from_args(args, default_config):
|
| 59 |
-
if OMEGACONF_IMPORT_ERROR is not None:
|
| 60 |
-
raise ImportError(
|
| 61 |
-
"OmegaConf is required to use get_cfg_from_args but not installed. "
|
| 62 |
-
"Please install OmegaConf to use this function."
|
| 63 |
-
) from OMEGACONF_IMPORT_ERROR
|
| 64 |
-
|
| 65 |
-
args.output_dir = os.path.abspath(args.output_dir)
|
| 66 |
-
args.opts = [] if args.opts is None else args.opts
|
| 67 |
-
args.opts += [f"train.output_dir={args.output_dir}"]
|
| 68 |
-
default_cfg = OmegaConf.create(default_config)
|
| 69 |
-
cfg = OmegaConf.load(args.config_file)
|
| 70 |
-
cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
|
| 71 |
-
return cfg
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def random_seed(args):
|
| 75 |
-
seed = getattr(args, "seed", 0)
|
| 76 |
-
rank = distributed.get_global_rank()
|
| 77 |
-
|
| 78 |
-
_utils.fix_random_seeds(seed + rank)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def setup(args, default_config):
|
| 82 |
-
"""
|
| 83 |
-
Create configs and perform basic setups.
|
| 84 |
-
"""
|
| 85 |
-
cfg = get_cfg_from_args(args, default_config)
|
| 86 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
| 87 |
-
random_seed(args)
|
| 88 |
-
accelerator = distributed.init_distributed(cfg)
|
| 89 |
-
apply_scaling_rules_to_cfg(cfg)
|
| 90 |
-
write_config(cfg, args.output_dir)
|
| 91 |
-
return cfg, accelerator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/dataloader.py
DELETED
|
@@ -1,126 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
import os
|
| 3 |
-
from typing import Union, Callable, Optional, List
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch.utils.data import ConcatDataset
|
| 7 |
-
|
| 8 |
-
MONAI_IMPORT_ERROR = None
|
| 9 |
-
try:
|
| 10 |
-
import monai.data as data
|
| 11 |
-
except ImportError as e:
|
| 12 |
-
data = None # type: ignore
|
| 13 |
-
MONAI_IMPORT_ERROR = e
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def get_dataloader(
|
| 18 |
-
datasets: Union[str, List[str]],
|
| 19 |
-
data_dir: str,
|
| 20 |
-
include_reports: bool = False,
|
| 21 |
-
include_labels: bool = False,
|
| 22 |
-
cache_dataset: bool = False,
|
| 23 |
-
cache_dir: Optional[str] = None,
|
| 24 |
-
use_gds: bool = False,
|
| 25 |
-
transform: Optional[Callable] = None,
|
| 26 |
-
fraction: float = 1.0,
|
| 27 |
-
batch_size: int = 64,
|
| 28 |
-
num_workers: int = 4,
|
| 29 |
-
pin_memory: bool = True,
|
| 30 |
-
shuffle: bool = True,
|
| 31 |
-
collate_fn: Optional[Callable] = None,
|
| 32 |
-
drop_last: bool = True,
|
| 33 |
-
persistent_workers: bool = True,
|
| 34 |
-
use_thread: bool = False,
|
| 35 |
-
) -> "DataLoader":
|
| 36 |
-
"""
|
| 37 |
-
Get dataloader for training.
|
| 38 |
-
"""
|
| 39 |
-
if MONAI_IMPORT_ERROR is not None:
|
| 40 |
-
raise ImportError(
|
| 41 |
-
"MONAI is required to use get_dataloader but not installed. "
|
| 42 |
-
"Please install MONAI to use this function."
|
| 43 |
-
) from MONAI_IMPORT_ERROR
|
| 44 |
-
|
| 45 |
-
if isinstance(datasets, str):
|
| 46 |
-
datasets = [datasets]
|
| 47 |
-
|
| 48 |
-
# Validate constraints
|
| 49 |
-
if include_reports:
|
| 50 |
-
assert set(datasets).issubset({"ct_rate", "merlin", "inspect"}), \
|
| 51 |
-
"When include_reports=True, only 'ct_rate', 'merlin', and 'inspect' are allowed."
|
| 52 |
-
if include_labels:
|
| 53 |
-
assert set(datasets).issubset({"abdomen_atlas", "abdomenct_1k"}), \
|
| 54 |
-
"When include_labels=True, only 'abdomen_atlas' and 'abdomenct_1k' are allowed."
|
| 55 |
-
if use_gds:
|
| 56 |
-
assert cache_dataset, "GDS requires cache_dataset=True."
|
| 57 |
-
assert torch.cuda.is_available(), "GDS requires CUDA to be available."
|
| 58 |
-
|
| 59 |
-
# Dataset configurations
|
| 60 |
-
DATASET_CONFIGS = {
|
| 61 |
-
"ct_rate": {"folder": "CT-RATE", "base_name": "CTRate",
|
| 62 |
-
"extra": {"include_reports": include_reports}},
|
| 63 |
-
"inspect": {"folder": "INSPECT", "base_name": "Inspect",
|
| 64 |
-
"extra": {"include_reports": include_reports}},
|
| 65 |
-
"merlin": {"folder": "MERLIN", "base_name": "Merlin",
|
| 66 |
-
"extra": {"include_reports": include_reports}},
|
| 67 |
-
"nlst": {"folder": "NLST", "base_name": "Nlst"},
|
| 68 |
-
"amos": {"folder": "Amos", "base_name": "Amos"},
|
| 69 |
-
"abdomen_atlas": {"folder": "AbdomenAtlas1.0Mini", "base_name": "AbdomenAtlas",
|
| 70 |
-
"extra": {"include_labels": include_labels}},
|
| 71 |
-
"panorama": {"folder": "PANORAMA", "base_name": "Panorama"},
|
| 72 |
-
"abdomenct_1k": {"folder": "AbdomenCT-1K", "base_name": "AbdomenCT1K",
|
| 73 |
-
"extra": {"include_labels": include_labels}},
|
| 74 |
-
}
|
| 75 |
-
|
| 76 |
-
datasets_list = []
|
| 77 |
-
for ds in datasets:
|
| 78 |
-
if ds.lower() not in DATASET_CONFIGS:
|
| 79 |
-
raise NotImplementedError(f"Dataset {ds} not implemented.")
|
| 80 |
-
|
| 81 |
-
cfg = DATASET_CONFIGS[ds.lower()]
|
| 82 |
-
folder = cfg["folder"]
|
| 83 |
-
extra_args = cfg.get("extra", {})
|
| 84 |
-
|
| 85 |
-
kwargs = {
|
| 86 |
-
"data_dir": os.path.join(data_dir, folder),
|
| 87 |
-
"transform": transform,
|
| 88 |
-
"fraction": fraction,
|
| 89 |
-
**extra_args,
|
| 90 |
-
}
|
| 91 |
-
|
| 92 |
-
base_name = cfg["base_name"]
|
| 93 |
-
class_suffix = "Dataset"
|
| 94 |
-
if cache_dataset:
|
| 95 |
-
class_suffix = "GDSDataset" if use_gds else "PersistentDataset"
|
| 96 |
-
|
| 97 |
-
class_name = f"{base_name}{class_suffix}"
|
| 98 |
-
DatasetClass = getattr(__import__("spectre.data", fromlist=[class_name]), class_name)
|
| 99 |
-
|
| 100 |
-
if cache_dataset:
|
| 101 |
-
kwargs["cache_dir"] = os.path.join(cache_dir, folder)
|
| 102 |
-
if use_gds:
|
| 103 |
-
kwargs["device"] = torch.cuda.current_device()
|
| 104 |
-
|
| 105 |
-
datasets_list.append(DatasetClass(**kwargs))
|
| 106 |
-
|
| 107 |
-
dataset = datasets_list[0] if len(datasets_list) == 1 else ConcatDataset(datasets_list)
|
| 108 |
-
|
| 109 |
-
loader_cls = getattr(data, "ThreadDataLoader" if use_thread else "DataLoader")
|
| 110 |
-
loader_kwargs = {
|
| 111 |
-
"dataset": dataset,
|
| 112 |
-
"batch_size": batch_size,
|
| 113 |
-
"num_workers": num_workers,
|
| 114 |
-
"shuffle": shuffle,
|
| 115 |
-
"drop_last": drop_last,
|
| 116 |
-
}
|
| 117 |
-
|
| 118 |
-
if not use_thread:
|
| 119 |
-
loader_kwargs.update({
|
| 120 |
-
"pin_memory": pin_memory,
|
| 121 |
-
"persistent_workers": persistent_workers
|
| 122 |
-
})
|
| 123 |
-
if collate_fn is not None:
|
| 124 |
-
loader_kwargs["collate_fn"] = collate_fn
|
| 125 |
-
|
| 126 |
-
return loader_cls(**loader_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/distributed.py
DELETED
|
@@ -1,92 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
import torch.distributed as dist
|
| 4 |
-
|
| 5 |
-
ACCELERATE_IMPORT_ERROR = None
|
| 6 |
-
try:
|
| 7 |
-
from accelerate import Accelerator, DataLoaderConfiguration
|
| 8 |
-
except ImportError as e:
|
| 9 |
-
Accelerator = None # type: ignore
|
| 10 |
-
DataLoaderConfiguration = None # type: ignore
|
| 11 |
-
ACCELERATE_IMPORT_ERROR = e
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def is_enabled() -> bool:
|
| 15 |
-
"""
|
| 16 |
-
Returns:
|
| 17 |
-
True if distributed training is enabled
|
| 18 |
-
"""
|
| 19 |
-
return dist.is_available() and dist.is_initialized()
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def get_global_size() -> int:
|
| 23 |
-
"""
|
| 24 |
-
Returns:
|
| 25 |
-
Number of processes in the distributed group
|
| 26 |
-
"""
|
| 27 |
-
if not is_enabled():
|
| 28 |
-
return 1
|
| 29 |
-
return dist.get_world_size()
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def get_global_rank() -> int:
|
| 33 |
-
"""
|
| 34 |
-
Returns:
|
| 35 |
-
The rank of the current process in the distributed group
|
| 36 |
-
"""
|
| 37 |
-
if not is_enabled():
|
| 38 |
-
return 0
|
| 39 |
-
return dist.get_rank()
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def get_local_size() -> int:
|
| 43 |
-
"""
|
| 44 |
-
Returns:
|
| 45 |
-
Number of processes on the current machine
|
| 46 |
-
"""
|
| 47 |
-
if not is_enabled():
|
| 48 |
-
return 1
|
| 49 |
-
return int(os.environ.get("LOCAL_SIZE", 1))
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def get_local_rank() -> int:
|
| 53 |
-
"""
|
| 54 |
-
Returns:
|
| 55 |
-
The rank of the current process on the current machine
|
| 56 |
-
"""
|
| 57 |
-
if not is_enabled():
|
| 58 |
-
return 0
|
| 59 |
-
return int(os.environ.get("LOCAL_RANK", 0))
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def init_distributed(cfg):
|
| 63 |
-
"""
|
| 64 |
-
Initialize distributed training.
|
| 65 |
-
"""
|
| 66 |
-
if ACCELERATE_IMPORT_ERROR is not None:
|
| 67 |
-
raise ImportError(
|
| 68 |
-
"Accelerate is required to use init_distributed but not installed. "
|
| 69 |
-
"Please install Accelerate to use this function."
|
| 70 |
-
) from ACCELERATE_IMPORT_ERROR
|
| 71 |
-
|
| 72 |
-
# Initialize accelerator
|
| 73 |
-
dataloader_config = DataLoaderConfiguration(
|
| 74 |
-
non_blocking=cfg.train.pin_memory,
|
| 75 |
-
)
|
| 76 |
-
accelerator = Accelerator(
|
| 77 |
-
gradient_accumulation_steps=cfg.train.grad_accum_steps,
|
| 78 |
-
log_with="wandb" if cfg.train.log_wandb else None,
|
| 79 |
-
dataloader_config=dataloader_config,
|
| 80 |
-
)
|
| 81 |
-
|
| 82 |
-
# Initialize wandb
|
| 83 |
-
if cfg.train.log_wandb:
|
| 84 |
-
accelerator.init_trackers(
|
| 85 |
-
project_name="spectre",
|
| 86 |
-
config={k: v for d in cfg.values() for k, v in d.items()},
|
| 87 |
-
init_kwargs={
|
| 88 |
-
"dir": os.path.join(cfg.train.output_dir, "logs"),
|
| 89 |
-
},
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
return accelerator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/lora.py
DELETED
|
@@ -1,38 +0,0 @@
|
|
| 1 |
-
import torch.nn as nn
|
| 2 |
-
import loralib as lora
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
def add_lora_adapters(
|
| 6 |
-
root_module: nn.Module,
|
| 7 |
-
r: int = 8,
|
| 8 |
-
lora_alpha: int = 32,
|
| 9 |
-
lora_dropout: float = 0.05,
|
| 10 |
-
target_keywords: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
|
| 11 |
-
) -> None:
|
| 12 |
-
"""
|
| 13 |
-
Recursively traverses the model and replaces every `nn.Linear`
|
| 14 |
-
whose name contains one of `target_keywords` with a LoRA-augmented
|
| 15 |
-
linear layer from loralib.
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
for name, child in list(root_module.named_children()):
|
| 19 |
-
# If the child is itself a container, recurse first
|
| 20 |
-
add_lora_adapters(child, r, lora_alpha, lora_dropout, target_keywords)
|
| 21 |
-
|
| 22 |
-
# Replace target linear layers
|
| 23 |
-
if isinstance(child, nn.Linear) and any(k in name for k in target_keywords):
|
| 24 |
-
lora_layer = lora.Linear( # loralib wrapper
|
| 25 |
-
in_features=child.in_features,
|
| 26 |
-
out_features=child.out_features,
|
| 27 |
-
r=r,
|
| 28 |
-
lora_alpha=lora_alpha,
|
| 29 |
-
lora_dropout=lora_dropout,
|
| 30 |
-
bias=child.bias is not None,
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
# copy original weights so that behaviour is identical pre-training
|
| 34 |
-
lora_layer.weight.data = child.weight.data.clone()
|
| 35 |
-
if child.bias is not None:
|
| 36 |
-
lora_layer.bias.data = child.bias.data.clone()
|
| 37 |
-
|
| 38 |
-
setattr(root_module, name, lora_layer) # hot-swap!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/masking.py
DELETED
|
@@ -1,196 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
from typing import Optional, Tuple, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def _random_block_mask(
|
| 8 |
-
size: Tuple[int, int, int],
|
| 9 |
-
num_masks: int,
|
| 10 |
-
min_num_masks_per_block: int = 4,
|
| 11 |
-
max_num_masks_per_block: Optional[int] = None,
|
| 12 |
-
max_attempts_per_block: int = 10,
|
| 13 |
-
generator: Optional[torch.Generator] = None,
|
| 14 |
-
device: Optional[Union[torch.device, str]] = None,
|
| 15 |
-
) -> torch.Tensor:
|
| 16 |
-
"""3D helper: generate a (H, W, D) boolean mask by placing cuboidal blocks.
|
| 17 |
-
|
| 18 |
-
- size: (H, W, D)
|
| 19 |
-
- num_masks: target total number of masked voxels for this image
|
| 20 |
-
- min_num_masks_per_block / max_num_masks_per_block: voxel-range per block
|
| 21 |
-
"""
|
| 22 |
-
H, W, D = size
|
| 23 |
-
total = H * W * D
|
| 24 |
-
num_masks = min(max(0, int(num_masks)), total)
|
| 25 |
-
|
| 26 |
-
if max_num_masks_per_block is None:
|
| 27 |
-
max_num_masks_per_block = max(1, num_masks)
|
| 28 |
-
|
| 29 |
-
mask = torch.zeros((H, W, D), dtype=torch.bool, device=device)
|
| 30 |
-
masked_count = 0
|
| 31 |
-
global_attempts = 0
|
| 32 |
-
|
| 33 |
-
orders = [(0, 1, 2), (1, 2, 0), (2, 0, 1)]
|
| 34 |
-
|
| 35 |
-
# Try to place blocks until we have enough masked voxels or we exceed attempts
|
| 36 |
-
while masked_count < num_masks and global_attempts < max_attempts_per_block:
|
| 37 |
-
global_attempts += 1
|
| 38 |
-
|
| 39 |
-
# choose target voxels for this block
|
| 40 |
-
target_voxels = int(torch.randint(
|
| 41 |
-
min_num_masks_per_block, max_num_masks_per_block + 1, (1,), generator=generator
|
| 42 |
-
).item())
|
| 43 |
-
|
| 44 |
-
found = False
|
| 45 |
-
local_attempts = 0
|
| 46 |
-
while not found and local_attempts < max_attempts_per_block:
|
| 47 |
-
local_attempts += 1
|
| 48 |
-
|
| 49 |
-
# random pick order for dims to reduce bias
|
| 50 |
-
order_idx = int(torch.randint(0, 3, (1,), generator=generator).item())
|
| 51 |
-
order = orders[order_idx]
|
| 52 |
-
|
| 53 |
-
# pick first dimension
|
| 54 |
-
if order[0] == 0:
|
| 55 |
-
h = int(torch.randint(1, min(H, target_voxels) + 1, (1,), generator=generator).item())
|
| 56 |
-
elif order[0] == 1:
|
| 57 |
-
w = int(torch.randint(1, min(W, target_voxels) + 1, (1,), generator=generator).item())
|
| 58 |
-
else:
|
| 59 |
-
d = int(torch.randint(1, min(D, target_voxels) + 1, (1,), generator=generator).item())
|
| 60 |
-
|
| 61 |
-
# progressively choose remaining dims while ensuring feasibility
|
| 62 |
-
try:
|
| 63 |
-
if order[0] == 0:
|
| 64 |
-
# h chosen -> pick w then compute d_needed
|
| 65 |
-
max_w = max(1, min(W, target_voxels // h))
|
| 66 |
-
w = int(torch.randint(1, max_w + 1, (1,), generator=generator).item())
|
| 67 |
-
d_needed = math.ceil(target_voxels / (h * w))
|
| 68 |
-
if d_needed <= D:
|
| 69 |
-
d = max(1, d_needed)
|
| 70 |
-
found = True
|
| 71 |
-
elif order[0] == 1:
|
| 72 |
-
# w chosen -> pick d then compute h_needed
|
| 73 |
-
max_d = max(1, min(D, target_voxels // w))
|
| 74 |
-
d = int(torch.randint(1, max_d + 1, (1,), generator=generator).item())
|
| 75 |
-
h_needed = math.ceil(target_voxels / (d * w))
|
| 76 |
-
if h_needed <= H:
|
| 77 |
-
h = max(1, h_needed)
|
| 78 |
-
found = True
|
| 79 |
-
else:
|
| 80 |
-
# d chosen -> pick h then compute w_needed
|
| 81 |
-
max_h = max(1, min(H, target_voxels // d))
|
| 82 |
-
h = int(torch.randint(1, max_h + 1, (1,), generator=generator).item())
|
| 83 |
-
w_needed = math.ceil(target_voxels / (d * h))
|
| 84 |
-
if w_needed <= W:
|
| 85 |
-
w = max(1, w_needed)
|
| 86 |
-
found = True
|
| 87 |
-
|
| 88 |
-
except ValueError:
|
| 89 |
-
# in case of invalid ranges (defensive); just continue trying
|
| 90 |
-
continue
|
| 91 |
-
|
| 92 |
-
# fallback alternative attempt: try simple factorization heuristics
|
| 93 |
-
if not found:
|
| 94 |
-
# attempt small-to-large factorization
|
| 95 |
-
for hh in range(1, min(H, target_voxels) + 1):
|
| 96 |
-
for ww in range(1, min(W, target_voxels // hh) + 1):
|
| 97 |
-
dd = math.ceil(target_voxels / (hh * ww))
|
| 98 |
-
if dd <= D:
|
| 99 |
-
h, w, d = hh, ww, dd
|
| 100 |
-
found = True
|
| 101 |
-
break
|
| 102 |
-
if found:
|
| 103 |
-
break
|
| 104 |
-
|
| 105 |
-
if not found:
|
| 106 |
-
# couldn't find a fitting block this global attempt; move on
|
| 107 |
-
continue
|
| 108 |
-
|
| 109 |
-
# clamp block dims to volume just in case and ensure at least 1
|
| 110 |
-
h = min(max(1, int(h)), H)
|
| 111 |
-
w = min(max(1, int(w)), W)
|
| 112 |
-
d = min(max(1, int(d)), D)
|
| 113 |
-
|
| 114 |
-
# choose random location so block fits
|
| 115 |
-
x0 = int(torch.randint(0, (H - h) + 1, (1,), generator=generator).item()) if H - h > 0 else 0
|
| 116 |
-
y0 = int(torch.randint(0, (W - w) + 1, (1,), generator=generator).item()) if W - w > 0 else 0
|
| 117 |
-
z0 = int(torch.randint(0, (D - d) + 1, (1,), generator=generator).item()) if D - d > 0 else 0
|
| 118 |
-
|
| 119 |
-
mask[x0 : x0 + h, y0 : y0 + w, z0 : z0 + d] = True
|
| 120 |
-
masked_count = int(mask.sum().item())
|
| 121 |
-
|
| 122 |
-
# If still short, fill remaining voxels at random positions
|
| 123 |
-
if masked_count < num_masks:
|
| 124 |
-
remaining = num_masks - masked_count
|
| 125 |
-
indices = torch.nonzero(~mask, as_tuple=False)
|
| 126 |
-
if indices.numel() > 0:
|
| 127 |
-
perm = torch.randperm(indices.shape[0], generator=generator, device=mask.device)
|
| 128 |
-
pick = indices[perm[:remaining]]
|
| 129 |
-
mask[pick[:, 0], pick[:, 1], pick[:, 2]] = True
|
| 130 |
-
|
| 131 |
-
return mask
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def random_block_mask(
|
| 135 |
-
size: Tuple[int, int, int, int],
|
| 136 |
-
batch_mask_ratio: float = 0.5,
|
| 137 |
-
min_image_mask_ratio: float = 0.1,
|
| 138 |
-
max_image_mask_ratio: float = 0.5,
|
| 139 |
-
min_num_masks_per_block: int = 4,
|
| 140 |
-
max_num_masks_per_block: Optional[int] = None,
|
| 141 |
-
max_attempts_per_block: int = 10,
|
| 142 |
-
generator: Optional[torch.Generator] = None,
|
| 143 |
-
device: Optional[Union[torch.device, str]] = None,
|
| 144 |
-
) -> torch.Tensor:
|
| 145 |
-
"""Create random block masks for 3D volumes only.
|
| 146 |
-
|
| 147 |
-
Args:
|
| 148 |
-
size: (B, H, W, D)
|
| 149 |
-
batch_mask_ratio: fraction of images in the batch to apply masking to
|
| 150 |
-
min_image_mask_ratio / max_image_mask_ratio: per-image mask fraction range
|
| 151 |
-
min_num_masks_per_block / max_num_masks_per_block: voxels per block range
|
| 152 |
-
max_attempts_per_block: attempts to find a fitting block
|
| 153 |
-
generator: optional torch.Generator for reproducibility.
|
| 154 |
-
device: device for returned tensor
|
| 155 |
-
|
| 156 |
-
Returns:
|
| 157 |
-
boolean tensor with shape (B, H, W, D)
|
| 158 |
-
"""
|
| 159 |
-
if len(size) != 4:
|
| 160 |
-
raise ValueError("size must be (B, H, W, D) for 3D masking.")
|
| 161 |
-
|
| 162 |
-
B, H, W, D = size
|
| 163 |
-
|
| 164 |
-
if max_image_mask_ratio < min_image_mask_ratio:
|
| 165 |
-
raise ValueError("max_image_mask_ratio must be >= min_image_mask_ratio.")
|
| 166 |
-
|
| 167 |
-
num_images_masked = int(B * batch_mask_ratio)
|
| 168 |
-
probs = torch.linspace(min_image_mask_ratio, max_image_mask_ratio, num_images_masked + 1).tolist()
|
| 169 |
-
|
| 170 |
-
image_masks = []
|
| 171 |
-
total_voxels = H * W * D
|
| 172 |
-
|
| 173 |
-
for prob_min, prob_max in zip(probs[:-1], probs[1:]):
|
| 174 |
-
# choose number of masked voxels for this image
|
| 175 |
-
u = float(prob_min + (prob_max - prob_min) * torch.rand(1, generator=generator).item())
|
| 176 |
-
num_mask = int(total_voxels * u)
|
| 177 |
-
image_masks.append(
|
| 178 |
-
_random_block_mask(
|
| 179 |
-
size=(H, W, D),
|
| 180 |
-
num_masks=num_mask,
|
| 181 |
-
min_num_masks_per_block=min_num_masks_per_block,
|
| 182 |
-
max_num_masks_per_block=max_num_masks_per_block,
|
| 183 |
-
max_attempts_per_block=max_attempts_per_block,
|
| 184 |
-
generator=generator,
|
| 185 |
-
device=device,
|
| 186 |
-
)
|
| 187 |
-
)
|
| 188 |
-
|
| 189 |
-
# Add non-masked images (all False) to fill the batch
|
| 190 |
-
for _ in range(num_images_masked, B):
|
| 191 |
-
image_masks.append(torch.zeros((H, W, D), dtype=torch.bool, device=device))
|
| 192 |
-
|
| 193 |
-
perm = torch.randperm(B, generator=generator).tolist()
|
| 194 |
-
image_masks = [image_masks[i] for i in perm]
|
| 195 |
-
|
| 196 |
-
return torch.stack(image_masks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/param_groups.py
DELETED
|
@@ -1,118 +0,0 @@
|
|
| 1 |
-
import torch.nn as nn
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def get_vit_lr_decay_rate(
|
| 5 |
-
name: str,
|
| 6 |
-
llrd_factor: float = 1.0,
|
| 7 |
-
num_layers: int = 12,
|
| 8 |
-
force_is_backbone: bool = False,
|
| 9 |
-
shift: int = 0,
|
| 10 |
-
) -> float:
|
| 11 |
-
"""
|
| 12 |
-
Get the layer-wise learning rate decay (LLRD) rate for a given parameter name.
|
| 13 |
-
|
| 14 |
-
Args:
|
| 15 |
-
name:
|
| 16 |
-
The name of the parameter.
|
| 17 |
-
llrd_factor:
|
| 18 |
-
The decay factor for each layer.
|
| 19 |
-
num_layers:
|
| 20 |
-
The total number of layers in the model.
|
| 21 |
-
force_is_backbone:
|
| 22 |
-
If True, forces the function to treat the parameter as part of the backbone.
|
| 23 |
-
shift:
|
| 24 |
-
An integer to shift the layer ids, useful when combining multiple modules.
|
| 25 |
-
|
| 26 |
-
Returns:
|
| 27 |
-
The learning rate multiplier for the parameter.
|
| 28 |
-
"""
|
| 29 |
-
layer_id = num_layers + 1
|
| 30 |
-
if name.startswith("backbone") or force_is_backbone:
|
| 31 |
-
if (
|
| 32 |
-
".pos_embed" in name
|
| 33 |
-
or ".patch_embed" in name
|
| 34 |
-
or ".patch_proj" in name
|
| 35 |
-
or ".mask_token" in name
|
| 36 |
-
or ".cls_token" in name
|
| 37 |
-
or ".reg_token" in name
|
| 38 |
-
):
|
| 39 |
-
layer_id = 0
|
| 40 |
-
elif ".blocks." in name:
|
| 41 |
-
layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + shift
|
| 42 |
-
|
| 43 |
-
return llrd_factor ** (num_layers + 1 - layer_id)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def get_param_groups_with_decay(
|
| 47 |
-
model: nn.Module,
|
| 48 |
-
llrd_factor: float = 1.0,
|
| 49 |
-
patch_embed_lr_mult: float = 1.0,
|
| 50 |
-
projection_head_wd_mult: float = 1.0,
|
| 51 |
-
lora_lr_factor: float = 1.0,
|
| 52 |
-
num_layers: int | None = None,
|
| 53 |
-
):
|
| 54 |
-
|
| 55 |
-
force_is_backbone = False
|
| 56 |
-
shift = 0
|
| 57 |
-
if num_layers is not None:
|
| 58 |
-
num_layers = num_layers
|
| 59 |
-
elif hasattr(model, "n_blocks"):
|
| 60 |
-
num_layers = model.n_blocks
|
| 61 |
-
force_is_backbone = True
|
| 62 |
-
elif hasattr(model, "blocks"):
|
| 63 |
-
num_layers = len(model.blocks)
|
| 64 |
-
force_is_backbone = True
|
| 65 |
-
elif hasattr(model, "backbone") and hasattr(model.backbone, "blocks"):
|
| 66 |
-
num_layers = len(model.backbone.blocks)
|
| 67 |
-
elif hasattr(model, "backbone_student") and hasattr(model.backbone_student, "blocks"): # DINO specific
|
| 68 |
-
num_layers = len(model.backbone_student.blocks)
|
| 69 |
-
elif hasattr(model, "backbone_student") and hasattr(model.backbone_student, "vit") and hasattr(model.backbone_student.vit, "blocks"): # DINOv2 specific
|
| 70 |
-
num_layers = len(model.backbone_student.vit.blocks)
|
| 71 |
-
elif hasattr(model, "backbone_image") and hasattr(model.backbone_image, "blocks"): # SigLIP specific
|
| 72 |
-
if not hasattr(model, "feature_comb_image") or model.feature_comb_image is None:
|
| 73 |
-
num_layers = len(model.backbone_image.blocks)
|
| 74 |
-
else:
|
| 75 |
-
num_layers = len(model.backbone_image.blocks) + len(model.feature_comb_image.blocks)
|
| 76 |
-
shift = len(model.backbone_image.blocks)
|
| 77 |
-
force_is_backbone = True
|
| 78 |
-
else:
|
| 79 |
-
num_layers = 0
|
| 80 |
-
|
| 81 |
-
all_param_groups = []
|
| 82 |
-
for n, p in model.named_parameters():
|
| 83 |
-
if not p.requires_grad:
|
| 84 |
-
continue
|
| 85 |
-
if not "lora_" in n:
|
| 86 |
-
s = shift if "feature_comb" in n else 0
|
| 87 |
-
llrd_rate = get_vit_lr_decay_rate(
|
| 88 |
-
n, llrd_factor, num_layers, force_is_backbone, s,
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
d = {
|
| 92 |
-
"name": n,
|
| 93 |
-
"params": p,
|
| 94 |
-
"lr_mult": llrd_rate,
|
| 95 |
-
"wd_mult": 1.0,
|
| 96 |
-
}
|
| 97 |
-
|
| 98 |
-
if "head" in n or "projection" in n:
|
| 99 |
-
d["wd_mult"] = projection_head_wd_mult
|
| 100 |
-
|
| 101 |
-
# No weight-decay on biases, norm parameters, layer scale gamma, learned tokens and embeddings
|
| 102 |
-
if n.endswith("bias") or "norm" in n or "gamma" in n or "fourrier_w" in n:
|
| 103 |
-
d["wd_mult"] = 0.0
|
| 104 |
-
|
| 105 |
-
if "patch_embed" in n:
|
| 106 |
-
d["lr_mult"] *= patch_embed_lr_mult
|
| 107 |
-
|
| 108 |
-
else:
|
| 109 |
-
# LoRA parameters
|
| 110 |
-
d = {
|
| 111 |
-
"name": n,
|
| 112 |
-
"params": p,
|
| 113 |
-
"lr_mult": lora_lr_factor,
|
| 114 |
-
"wd_mult": 1.0,
|
| 115 |
-
}
|
| 116 |
-
|
| 117 |
-
all_param_groups.append(d)
|
| 118 |
-
return all_param_groups
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spectre/utils/scheduler.py
DELETED
|
@@ -1,236 +0,0 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
from typing import Optional
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def linear_warmup_schedule(
|
| 9 |
-
step: int,
|
| 10 |
-
warmup_steps: int,
|
| 11 |
-
start_value: float,
|
| 12 |
-
end_value: float,
|
| 13 |
-
) -> float:
|
| 14 |
-
if warmup_steps < 0:
|
| 15 |
-
raise ValueError(f"Warmup steps {warmup_steps} can't be negative.")
|
| 16 |
-
if step < 0:
|
| 17 |
-
raise ValueError(f"Current step number {step} can't be negative.")
|
| 18 |
-
if start_value < 0:
|
| 19 |
-
raise ValueError(f"Start value {start_value} can't be negative.")
|
| 20 |
-
if end_value <= 0:
|
| 21 |
-
raise ValueError(f"End value {end_value} can't be non-positive.")
|
| 22 |
-
if start_value > end_value:
|
| 23 |
-
raise ValueError(
|
| 24 |
-
f"Start value {start_value} must be less than or equal to end value {end_value}."
|
| 25 |
-
)
|
| 26 |
-
if step < warmup_steps:
|
| 27 |
-
return start_value + step / warmup_steps * (end_value - start_value)
|
| 28 |
-
else:
|
| 29 |
-
return end_value
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def cosine_schedule(
|
| 33 |
-
step: int,
|
| 34 |
-
max_steps: int,
|
| 35 |
-
start_value: float,
|
| 36 |
-
end_value: float,
|
| 37 |
-
period: Optional[int] = None,
|
| 38 |
-
) -> float:
|
| 39 |
-
"""Use cosine decay to gradually modify start_value to reach target end_value.
|
| 40 |
-
|
| 41 |
-
Args:
|
| 42 |
-
step:
|
| 43 |
-
Current step number.
|
| 44 |
-
max_steps:
|
| 45 |
-
Total number of steps.
|
| 46 |
-
start_value:
|
| 47 |
-
Starting value.
|
| 48 |
-
end_value:
|
| 49 |
-
Target value.
|
| 50 |
-
period:
|
| 51 |
-
The number of steps over which the cosine function completes a full cycle.
|
| 52 |
-
Defaults to max_steps.
|
| 53 |
-
|
| 54 |
-
Returns:
|
| 55 |
-
Cosine decay value.
|
| 56 |
-
|
| 57 |
-
"""
|
| 58 |
-
if step < 0:
|
| 59 |
-
raise ValueError(f"Current step number {step} can't be negative.")
|
| 60 |
-
if max_steps < 1:
|
| 61 |
-
raise ValueError(f"Total step number {max_steps} must be >= 1.")
|
| 62 |
-
if period is None and step > max_steps:
|
| 63 |
-
warnings.warn(
|
| 64 |
-
f"Current step number {step} exceeds max_steps {max_steps}.",
|
| 65 |
-
category=RuntimeWarning,
|
| 66 |
-
)
|
| 67 |
-
if period is not None and period <= 0:
|
| 68 |
-
raise ValueError(f"Period {period} must be >= 1")
|
| 69 |
-
|
| 70 |
-
decay: float
|
| 71 |
-
if period is not None: # "cycle" based on period, if provided
|
| 72 |
-
decay = (
|
| 73 |
-
end_value
|
| 74 |
-
- (end_value - start_value) * (np.cos(2 * np.pi * step / period) + 1) / 2
|
| 75 |
-
)
|
| 76 |
-
elif max_steps == 1:
|
| 77 |
-
# Avoid division by zero
|
| 78 |
-
decay = end_value
|
| 79 |
-
elif step == max_steps:
|
| 80 |
-
# Special case for Pytorch Lightning which updates LR scheduler also for epoch
|
| 81 |
-
# after last training epoch.
|
| 82 |
-
decay = end_value
|
| 83 |
-
else:
|
| 84 |
-
decay = (
|
| 85 |
-
end_value
|
| 86 |
-
- (end_value - start_value)
|
| 87 |
-
* (np.cos(np.pi * step / (max_steps - 1)) + 1)
|
| 88 |
-
/ 2
|
| 89 |
-
)
|
| 90 |
-
return decay
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def cosine_warmup_schedule(
|
| 94 |
-
step: int,
|
| 95 |
-
max_steps: int,
|
| 96 |
-
start_value: float,
|
| 97 |
-
end_value: float,
|
| 98 |
-
warmup_steps: int,
|
| 99 |
-
warmup_start_value: float,
|
| 100 |
-
warmup_end_value: Optional[float] = None,
|
| 101 |
-
period: Optional[int] = None,
|
| 102 |
-
) -> float:
|
| 103 |
-
"""Use cosine decay to gradually modify start_value to reach target end_value.
|
| 104 |
-
|
| 105 |
-
Uses linear warmup for the first warmup_steps steps.
|
| 106 |
-
|
| 107 |
-
Args:
|
| 108 |
-
step:
|
| 109 |
-
Current step number.
|
| 110 |
-
max_steps:
|
| 111 |
-
Total number of steps.
|
| 112 |
-
start_value:
|
| 113 |
-
Starting value.
|
| 114 |
-
end_value:
|
| 115 |
-
Target value.
|
| 116 |
-
warmup_steps:
|
| 117 |
-
Number of steps for warmup.
|
| 118 |
-
warmup_start_value:
|
| 119 |
-
Starting value for warmup.
|
| 120 |
-
warmup_end_value:
|
| 121 |
-
Target value for warmup. Defaults to start_value.
|
| 122 |
-
period:
|
| 123 |
-
The number of steps over which the cosine function completes a full cycle.
|
| 124 |
-
Defaults to max_steps - warmup_steps.
|
| 125 |
-
Returns:
|
| 126 |
-
Cosine decay value.
|
| 127 |
-
"""
|
| 128 |
-
if warmup_steps < 0:
|
| 129 |
-
raise ValueError(f"Warmup steps {warmup_steps} can't be negative.")
|
| 130 |
-
if warmup_steps > max_steps:
|
| 131 |
-
raise ValueError(f"Warmup steps {warmup_steps} must be <= max_steps.")
|
| 132 |
-
if step > max_steps:
|
| 133 |
-
warnings.warn(
|
| 134 |
-
f"Current step number {step} exceeds max_steps {max_steps}.",
|
| 135 |
-
category=RuntimeWarning,
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
if warmup_end_value is None:
|
| 139 |
-
warmup_end_value = start_value
|
| 140 |
-
|
| 141 |
-
if step < warmup_steps:
|
| 142 |
-
# Use step + 1 to reach warmup_end_value at end of warmup. This means that the
|
| 143 |
-
# initial warmup_start_value is skipped which is oftentimes desired when setting
|
| 144 |
-
# it to 0 as this would result in no parameter updates.
|
| 145 |
-
return (
|
| 146 |
-
warmup_start_value
|
| 147 |
-
+ (warmup_end_value - warmup_start_value) * (step + 1) / warmup_steps
|
| 148 |
-
)
|
| 149 |
-
else:
|
| 150 |
-
max_steps = max_steps - warmup_steps if period is None else 1
|
| 151 |
-
return cosine_schedule(
|
| 152 |
-
step=step - warmup_steps,
|
| 153 |
-
max_steps=max_steps,
|
| 154 |
-
start_value=start_value,
|
| 155 |
-
end_value=end_value,
|
| 156 |
-
period=period,
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR):
|
| 161 |
-
"""Cosine warmup scheduler for learning rate.
|
| 162 |
-
|
| 163 |
-
Args:
|
| 164 |
-
optimizer:
|
| 165 |
-
Optimizer object to schedule the learning rate.
|
| 166 |
-
warmup_epochs:
|
| 167 |
-
Number of warmup epochs or steps.
|
| 168 |
-
max_epochs:
|
| 169 |
-
Total number of training epochs or steps.
|
| 170 |
-
last_epoch:
|
| 171 |
-
The index of last epoch or step.
|
| 172 |
-
start_value:
|
| 173 |
-
Starting learning rate.
|
| 174 |
-
end_value:
|
| 175 |
-
Target learning rate.
|
| 176 |
-
verbose:
|
| 177 |
-
If True, prints a message to stdout for each update.
|
| 178 |
-
warmup_start_value:
|
| 179 |
-
Starting learning rate for warmup.
|
| 180 |
-
warmup_end_value:
|
| 181 |
-
Target learning rate for warmup. Defaults to start_value.
|
| 182 |
-
|
| 183 |
-
Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index
|
| 184 |
-
can be used. The naming follows the PyTorch convention to use `epoch` for the steps
|
| 185 |
-
in the scheduler.
|
| 186 |
-
"""
|
| 187 |
-
|
| 188 |
-
def __init__(
|
| 189 |
-
self,
|
| 190 |
-
optimizer: torch.optim.Optimizer,
|
| 191 |
-
warmup_epochs: int,
|
| 192 |
-
max_epochs: int,
|
| 193 |
-
last_epoch: int = -1,
|
| 194 |
-
start_value: float = 1.0,
|
| 195 |
-
end_value: float = 0.001,
|
| 196 |
-
period: Optional[int] = None,
|
| 197 |
-
verbose: bool = False,
|
| 198 |
-
warmup_start_value: float = 0.0,
|
| 199 |
-
warmup_end_value: Optional[float] = None,
|
| 200 |
-
) -> None:
|
| 201 |
-
self.warmup_epochs = warmup_epochs
|
| 202 |
-
self.max_epochs = max_epochs
|
| 203 |
-
self.start_value = start_value
|
| 204 |
-
self.end_value = end_value
|
| 205 |
-
self.period = period
|
| 206 |
-
self.warmup_start_value = warmup_start_value
|
| 207 |
-
self.warmup_end_value = warmup_end_value
|
| 208 |
-
|
| 209 |
-
super().__init__(
|
| 210 |
-
optimizer=optimizer,
|
| 211 |
-
lr_lambda=self.scale_lr,
|
| 212 |
-
last_epoch=last_epoch,
|
| 213 |
-
verbose=verbose,
|
| 214 |
-
)
|
| 215 |
-
|
| 216 |
-
def scale_lr(self, epoch: int) -> float:
|
| 217 |
-
"""Scale learning rate according to the current epoch number.
|
| 218 |
-
|
| 219 |
-
Args:
|
| 220 |
-
epoch:
|
| 221 |
-
Current epoch number.
|
| 222 |
-
|
| 223 |
-
Returns:
|
| 224 |
-
Scaled learning rate.
|
| 225 |
-
|
| 226 |
-
"""
|
| 227 |
-
return cosine_warmup_schedule(
|
| 228 |
-
step=epoch,
|
| 229 |
-
max_steps=self.max_epochs,
|
| 230 |
-
start_value=self.start_value,
|
| 231 |
-
end_value=self.end_value,
|
| 232 |
-
warmup_steps=self.warmup_epochs,
|
| 233 |
-
warmup_start_value=self.warmup_start_value,
|
| 234 |
-
warmup_end_value=self.warmup_end_value,
|
| 235 |
-
period=self.period,
|
| 236 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|