cclaess committed on
Commit
af79f38
·
verified ·
1 Parent(s): 60e3fe9

Initial commit

Browse files
spectre/models/eomt.py DELETED
@@ -1,230 +0,0 @@
1
- # Adapted from https://github.com/tue-mps/eomt/
2
- from __future__ import annotations
3
-
4
- import math
5
- from typing import Optional, Tuple, Union
6
-
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
-
11
- from spectre.models.layers import LayerNorm3d
12
-
13
-
14
class ScaleBlock(nn.Module):
    """Strided convolution, GELU, depthwise 3x3x3 convolution, then LayerNorm3d.

    Args:
        embed_dim: Channel count; every layer keeps this width.
        scale_factors: Used as both the kernel size and the stride of the
            first convolution.
        conv1_layer: Constructor of the first convolution
            (``nn.ConvTranspose3d`` by default).
    """

    def __init__(
        self,
        embed_dim: int,
        scale_factors: Union[int, Tuple[int, int, int]] = (2, 2, 2),
        conv1_layer: nn.Module = nn.ConvTranspose3d,
    ):
        super().__init__()

        # Kernel size and stride are tied to the same value.
        resample_kwargs = dict(kernel_size=scale_factors, stride=scale_factors)
        self.conv1 = conv1_layer(embed_dim, embed_dim, **resample_kwargs)
        self.act = nn.GELU()

        # groups == channels: each channel is filtered independently.
        depthwise_kwargs = dict(
            kernel_size=3, padding=1, groups=embed_dim, bias=False
        )
        self.conv2 = nn.Conv3d(embed_dim, embed_dim, **depthwise_kwargs)
        self.norm = LayerNorm3d(embed_dim)

    def forward(self, x):
        """Run ``conv1 -> act -> conv2 -> norm`` on ``x`` and return the result."""
        activated = self.act(self.conv1(x))
        return self.norm(self.conv2(activated))
47
-
48
-
49
def compute_upscale_stages(patch_size, min_size=4):
    """Return, per dimension, how many 2x upscaling stages are needed.

    For each entry of ``patch_size`` the count is
    ``floor(log2(size)) - floor(log2(min_size))``, clamped at zero.

    Args:
        patch_size: Iterable of positive per-dimension patch sizes.
        min_size: Size below which no further upscaling is counted.

    Returns:
        A list with one non-negative stage count per dimension.
    """
    return [
        max(0, int(math.log2(extent)) - int(math.log2(min_size)))
        for extent in patch_size
    ]
56
-
57
-
58
class EoMT(nn.Module):
    """Mask/class prediction with learned queries on top of a ViT backbone.

    Learned query tokens are prepended to the token sequence before the last
    ``num_blocks`` backbone blocks. Class logits come from a linear head on the
    query tokens; mask logits are the dot product between projected queries
    and the upscaled patch-token feature map.

    Args:
        backbone: Transformer backbone. This class reads ``embed_dim``,
            ``patch_embed`` (``patch_size``, ``grid_size``),
            ``num_prefix_tokens``, ``blocks``, ``norm``, ``norm_pre``,
            ``patch_drop`` and ``_pos_embed`` from it.
        num_classes: Number of classes; the class head emits one extra logit.
        num_q: Number of learned queries.
        num_blocks: Number of trailing backbone blocks that process queries.
        masked_attn_enabled: If true, intermediate mask predictions are turned
            into attention masks for the query tokens.
    """

    def __init__(
        self,
        backbone: "VisionTransformer",
        num_classes: int,
        num_q: int,
        num_blocks=4,
        masked_attn_enabled=True,
    ):
        super().__init__()
        self.backbone = backbone
        self.num_q = num_q
        self.num_blocks = num_blocks
        self.masked_attn_enabled = masked_attn_enabled

        # One probability per query-processing block; values below 1 make
        # _disable_attn_mask randomly lift the mask for some queries.
        self.register_buffer("attn_mask_probs", torch.ones(num_blocks))

        width = self.backbone.embed_dim
        self.q = nn.Embedding(num_q, width)
        self.class_head = nn.Linear(width, num_classes + 1)
        self.mask_head = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, width),
        )

        patch_size = self.backbone.patch_embed.patch_size
        stages_per_axis = compute_upscale_stages(patch_size, min_size=4)

        # Stage s doubles an axis only while that axis still has stages left.
        self.upscale = nn.Sequential(
            *(
                ScaleBlock(
                    width,
                    scale_factors=tuple(
                        2 if stage < remaining else 1
                        for remaining in stages_per_axis
                    ),
                )
                for stage in range(max(stages_per_axis))
            )
        )

    def _predict(self, x: torch.Tensor):
        """Split ``x`` into query and patch tokens and return (mask, class) logits."""
        queries = x[:, : self.num_q, :]
        class_logits = self.class_head(queries)

        # Drop the queries and the backbone's prefix tokens, then fold the
        # remaining tokens back onto the patch grid: (B, C, *grid_size).
        patch_tokens = x[:, self.num_q + self.backbone.num_prefix_tokens :, :]
        feature_map = patch_tokens.transpose(1, 2).reshape(
            patch_tokens.shape[0], -1, *self.backbone.patch_embed.grid_size
        )

        query_embed = self.mask_head(queries)
        dense_features = self.upscale(feature_map)
        mask_logits = torch.einsum(
            "bqc, bchwd -> bqhwd", query_embed, dense_features
        )
        return mask_logits, class_logits

    @torch.compiler.disable
    def _disable_attn_mask(self, attn_mask, prob):
        """Set the query-to-patch mask to all-True for randomly chosen queries.

        A query is chosen when a uniform sample exceeds ``prob``; nothing
        happens when ``prob`` is not below 1. ``attn_mask`` is edited in place
        and also returned.
        """
        if not prob < 1:
            return attn_mask

        lifted = (
            torch.rand(attn_mask.shape[0], self.num_q, device=attn_mask.device)
            > prob
        )
        patch_start = self.num_q + self.backbone.num_prefix_tokens
        # Basic slicing yields a view, so this writes into attn_mask itself.
        query_to_patch = attn_mask[:, : self.num_q, patch_start:]
        query_to_patch[lifted] = True
        return attn_mask

    def _attn(self, module: 'Attention', x: torch.Tensor, mask: Optional[torch.Tensor], rope=None):
        """Run ``module``'s attention on ``x`` with an optional boolean mask.

        ``mask`` is (B, N, N) with True meaning "may attend"; it is broadcast
        over heads. ``rope`` is forwarded to ``module.apply_rotary_pos_emb``;
        a list of pairs is first stacked into a pair of tensors.
        """
        B, N, C = x.shape
        heads, head_dim = module.num_heads, module.head_dim

        q = module.q(x).reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)
        kv = module.kv(x).reshape(B, N, 2, heads, head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = module.q_norm(q), module.k_norm(k)

        if mask is not None:
            mask = mask[:, None, ...].expand(-1, heads, -1, -1)

        # Dropout inside fused attention follows this module's training flag.
        dropout_p = module.attn_drop.p if self.training else 0.0

        if rope is not None:
            if isinstance(rope, list):
                rope = tuple(
                    torch.stack([entry[part] for entry in rope], dim=0)
                    for part in range(2)
                )
            q, k = module.apply_rotary_pos_emb(q, k, rope)

        if module.fused_attn:
            out = F.scaled_dot_product_attention(q, k, v, mask, dropout_p)
        else:
            scores = (q @ k.transpose(-2, -1)) * module.scale
            if mask is not None:
                scores = scores.masked_fill(~mask, float("-inf"))
            weights = module.attn_drop(F.softmax(scores, dim=-1))
            out = weights @ v

        merged = out.transpose(1, 2).reshape(B, N, C)
        return module.proj_drop(module.proj(merged))

    def forward(self, x: torch.Tensor):
        """Return ``(mask_logits_per_layer, class_logits_per_layer)``.

        Each list ends with the prediction made after the last block. When
        masked attention is enabled, it is preceded by one prediction per
        query-processing block.
        """
        vit = self.backbone
        x = vit.patch_embed(x)
        x, rope = vit._pos_embed(x)
        x = vit.patch_drop(x)
        x = vit.norm_pre(x)

        attn_mask = None
        mask_logits_per_layer, class_logits_per_layer = [], []

        for i, block in enumerate(vit.blocks):
            first_query_block = len(vit.blocks) - self.num_blocks

            if i == first_query_block:
                # Prepend the learned queries to every sample in the batch.
                learned_q = self.q.weight[None, :, :].expand(x.shape[0], -1, -1)
                x = torch.cat((learned_q, x), dim=1)

            if self.masked_attn_enabled and i >= first_query_block:
                mask_logits, class_logits = self._predict(vit.norm(x))
                mask_logits_per_layer.append(mask_logits)
                class_logits_per_layer.append(class_logits)

                seq_len = x.shape[1]
                attn_mask = torch.ones(
                    x.shape[0],
                    seq_len,
                    seq_len,
                    dtype=torch.bool,
                    device=x.device,
                )
                # Resample the mask logits to the patch grid and flatten the
                # spatial axes so they line up with the patch tokens.
                coarse = F.interpolate(
                    mask_logits,
                    vit.patch_embed.grid_size,
                    mode="trilinear",
                )
                coarse = coarse.view(coarse.size(0), coarse.size(1), -1)
                patch_start = self.num_q + vit.num_prefix_tokens
                # A query may attend to a patch only where its logit is positive.
                attn_mask[:, : self.num_q, patch_start:] = coarse > 0
                attn_mask = self._disable_attn_mask(
                    attn_mask,
                    self.attn_mask_probs[i - len(vit.blocks) + self.num_blocks],
                )

            attn_out = self._attn(block.attn, block.norm1(x), attn_mask, rope)
            x = x + block.drop_path1(block.ls1(attn_out))
            x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))

        mask_logits, class_logits = self._predict(vit.norm(x))
        mask_logits_per_layer.append(mask_logits)
        class_logits_per_layer.append(class_logits)

        return (
            mask_logits_per_layer,
            class_logits_per_layer,
        )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/models/resnet.py DELETED
@@ -1,726 +0,0 @@
1
- import os
2
- import math
3
- from urllib.parse import urlparse
4
- from typing import Type, Any, Tuple, List, Optional, Union, Dict
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- from spectre.utils import to_ntuple
11
-
12
-
13
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
    """Return the padding ``((stride - 1) + dilation * (kernel_size - 1)) // 2``."""
    effective_extent = dilation * (kernel_size - 1)
    return ((stride - 1) + effective_extent) // 2
16
-
17
-
18
class BasicBlock(nn.Module):
    """Residual block made of two 3x3x3 convolutions (3D)."""

    expansion = 1

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            cardinality: int = 1,
            base_width: int = 64,
            reduce_first: int = 1,
            dilation: int = 1,
            first_dilation: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Type[nn.Module] = nn.BatchNorm3d,
    ):
        """
        Args:
            inplanes: Number of input channels.
            planes: Base width; the output has ``planes * expansion`` channels.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the shortcut branch.
            cardinality: Must be 1 for this block type.
            base_width: Must be 64 for this block type.
            reduce_first: Divisor applied to ``planes`` for the width between the two convolutions.
            dilation: Dilation (and padding) of the second convolution.
            first_dilation: Dilation (and padding) of the first convolution; falls back to ``dilation``.
            act_layer: Activation layer class, instantiated without arguments.
            norm_layer: Normalization layer class, instantiated with a channel count.
        """
        super().__init__()

        assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
        assert base_width == 64, 'BasicBlock does not support changing base width'
        mid_channels = planes // reduce_first
        out_channels = planes * self.expansion
        if not first_dilation:
            first_dilation = dilation

        # padding == dilation keeps the spatial size of a 3x3x3 conv at stride 1
        self.conv1 = nn.Conv3d(
            inplanes,
            mid_channels,
            kernel_size=3,
            stride=stride,
            padding=first_dilation,
            dilation=first_dilation,
            bias=False,
        )
        self.bn1 = norm_layer(mid_channels)
        self.act1 = act_layer()

        self.conv2 = nn.Conv3d(
            mid_channels,
            out_channels,
            kernel_size=3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        self.bn2 = norm_layer(out_channels)

        self.act2 = act_layer()
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def zero_init_last(self):
        """Zero the weight of the last norm layer, if it has one."""
        last_norm_weight = getattr(self.bn2, 'weight', None)
        if last_norm_weight is not None:
            nn.init.zeros_(self.bn2.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``act2(residual_branch(x) + shortcut(x))``."""
        identity = x

        out = self.act1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.downsample is not None:
            identity = self.downsample(identity)

        return self.act2(out + identity)
92
-
93
-
94
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1x1 -> 3x3x3 (optionally grouped) -> 1x1x1 convolutions."""

    expansion = 4

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            cardinality: int = 1,
            base_width: int = 64,
            reduce_first: int = 1,
            dilation: int = 1,
            first_dilation: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Type[nn.Module] = nn.BatchNorm3d,
    ):
        """
        Args:
            inplanes: Number of input channels.
            planes: Base width; the output has ``planes * expansion`` channels.
            stride: Stride of the 3x3x3 convolution.
            downsample: Optional module applied to the shortcut branch.
            cardinality: Number of groups of the 3x3x3 convolution.
            base_width: Scales the inner width: ``floor(planes * base_width / 64) * cardinality``.
            reduce_first: Divisor applied to the inner width for the first convolution's output.
            dilation: Stored on the block; used as the fallback for ``first_dilation``.
            first_dilation: Dilation (and padding) of the 3x3x3 convolution; falls back to ``dilation``.
            act_layer: Activation layer class, instantiated without arguments.
            norm_layer: Normalization layer class, instantiated with a channel count.
        """
        super().__init__()

        width = int(math.floor(planes * (base_width / 64)) * cardinality)
        squeeze_channels = width // reduce_first
        out_channels = planes * self.expansion
        if not first_dilation:
            first_dilation = dilation

        self.conv1 = nn.Conv3d(inplanes, squeeze_channels, kernel_size=1, bias=False)
        self.bn1 = norm_layer(squeeze_channels)
        self.act1 = act_layer()

        # The only spatial convolution; it carries the stride and the groups.
        self.conv2 = nn.Conv3d(
            squeeze_channels,
            width,
            kernel_size=3,
            stride=stride,
            padding=first_dilation,
            dilation=first_dilation,
            groups=cardinality,
            bias=False,
        )
        self.bn2 = norm_layer(width)
        self.act2 = act_layer()

        self.conv3 = nn.Conv3d(width, out_channels, kernel_size=1, bias=False)
        self.bn3 = norm_layer(out_channels)

        self.act3 = act_layer()
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def zero_init_last(self):
        """Zero the weight of the last norm layer, if it has one."""
        last_norm_weight = getattr(self.bn3, 'weight', None)
        if last_norm_weight is not None:
            nn.init.zeros_(self.bn3.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``act3(residual_branch(x) + shortcut(x))``."""
        identity = x

        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(identity)

        return self.act3(out + identity)
174
-
175
-
176
def downsample_conv(
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        first_dilation: Optional[int] = None,
        norm_layer: Optional[Type[nn.Module]] = None,
) -> nn.Module:
    """Build a convolution + norm projection for a residual shortcut.

    Args:
        in_channels: Channels entering the shortcut.
        out_channels: Channels the shortcut must produce.
        kernel_size: Requested kernel size; forced to 1 when both ``stride``
            and ``dilation`` are 1.
        stride: Stride of the convolution.
        dilation: Fallback dilation when ``first_dilation`` is not given.
        first_dilation: Dilation of the convolution; only used when the
            effective kernel size is larger than 1.
        norm_layer: Normalization layer class (``nn.BatchNorm3d`` if omitted).

    Returns:
        ``nn.Sequential`` of the bias-free convolution followed by the norm.
    """
    norm_layer = norm_layer or nn.BatchNorm3d

    if stride == 1 and dilation == 1:
        kernel_size = 1

    # A 1x1x1 kernel cannot be dilated, so dilation collapses to 1 there.
    if kernel_size > 1:
        first_dilation = first_dilation or dilation
    else:
        first_dilation = 1
    pad = get_padding(kernel_size, stride, first_dilation)

    projection = nn.Conv3d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=pad,
        dilation=first_dilation,
        bias=False,
    )
    return nn.Sequential(projection, norm_layer(out_channels))
195
-
196
-
197
def downsample_avg(
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        first_dilation: Optional[int] = None,
        norm_layer: Optional[Type[nn.Module]] = None,
) -> nn.Module:
    """Build an average-pool + 1x1x1 convolution + norm shortcut projection.

    ``kernel_size`` and ``first_dilation`` are accepted so the signature
    matches ``downsample_conv`` but are not used here.

    Args:
        in_channels: Channels entering the shortcut.
        out_channels: Channels the shortcut must produce.
        kernel_size: Unused.
        stride: Pooling stride when ``dilation`` is 1.
        dilation: When not 1, pooling runs with stride 1.
        first_dilation: Unused.
        norm_layer: Normalization layer class (``nn.BatchNorm3d`` if omitted).

    Returns:
        ``nn.Sequential`` of (pool or identity, 1x1x1 convolution, norm).
    """
    norm_layer = norm_layer or nn.BatchNorm3d

    needs_pool = not (stride == 1 and dilation == 1)
    if needs_pool:
        pool_stride = stride if dilation == 1 else 1
        pool = nn.AvgPool3d(2, pool_stride, ceil_mode=True, count_include_pad=False)
    else:
        pool = nn.Identity()

    projection = nn.Conv3d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
    return nn.Sequential(pool, projection, norm_layer(out_channels))
218
-
219
-
220
def make_blocks(
        block_fns: Tuple[Type[Union[BasicBlock, Bottleneck]], ...],
        channels: Tuple[int, ...],
        block_repeats: Tuple[int, ...],
        inplanes: int,
        reduce_first: int = 1,
        output_stride: int = 32,
        down_kernel_size: int = 1,
        avg_down: bool = False,
        **kwargs,
) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
    """Build the residual stages of a ResNet.

    Args:
        block_fns: One block class per stage (``BasicBlock`` or ``Bottleneck``).
        channels: Base width (``planes``) of each stage.
        block_repeats: Number of blocks in each stage.
        inplanes: Number of channels entering the first stage.
        reduce_first: Forwarded to every block.
        output_stride: Overall stride at which striding stops; later stages
            dilate instead of striding.
        down_kernel_size: Kernel size forwarded to the shortcut projection.
        avg_down: Use ``downsample_avg`` instead of ``downsample_conv`` for
            shortcut projections.
        **kwargs: Extra keyword arguments forwarded to every block; the
            ``norm_layer`` entry, if present, is also used for projections.

    Returns:
        ``(stages, feature_info)``: ``stages`` is a list of
        ``(name, nn.Sequential)`` pairs named ``layer1``, ``layer2``, ...;
        ``feature_info`` holds one dict per stage with ``num_chs``,
        ``reduction`` and ``module``.
    """
    stages = []
    feature_info = []
    # Stages start at an overall stride of 4 (the ResNet stem in this file
    # applies a stride-2 convolution followed by a stride-2 pooling stage).
    net_stride = 4
    dilation = prev_dilation = 1
    for stage_idx, (block_fn, planes, num_blocks) in enumerate(zip(block_fns, channels, block_repeats)):
        stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
        stride = 1 if stage_idx == 0 else 2
        if net_stride >= output_stride:
            # Target stride reached: trade striding for dilation.
            dilation *= stride
            stride = 1
        else:
            net_stride *= stride

        out_channels = planes * block_fn.expansion
        downsample = None
        if stride != 1 or inplanes != out_channels:
            # The shortcut needs a projection whenever shape changes.
            down_kwargs = dict(
                in_channels=inplanes,
                out_channels=out_channels,
                kernel_size=down_kernel_size,
                stride=stride,
                dilation=dilation,
                first_dilation=prev_dilation,
                norm_layer=kwargs.get('norm_layer'),
            )
            downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)

        block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, **kwargs)
        blocks = []
        for block_idx in range(num_blocks):
            # Only the first block of a stage strides and projects.
            is_first = block_idx == 0
            blocks.append(block_fn(
                inplanes,
                planes,
                stride if is_first else 1,
                downsample if is_first else None,
                first_dilation=prev_dilation,
                **block_kwargs,
            ))
            prev_dilation = dilation
            inplanes = out_channels

        stages.append((stage_name, nn.Sequential(*blocks)))
        feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))

    return stages, feature_info
280
-
281
-
282
def feature_take_indices(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
) -> Tuple[List[int], int]:
    """Resolve a feature selection into absolute indices.

    Note: The original docstring states this may be called from forward() and
    must therefore stay torchscript compatible.

    Args:
        num_features: Total number of features available.
        indices: Selection to resolve:
            None -> every feature
            int -> the last n features
            list/tuple of int -> those features (negative values count from the end)
        as_set: Return the indices as a set (ignored while scripting).

    Returns:
        (absolute indices as list or set, largest selected index)

    Raises:
        AssertionError: If a count or index is out of range.
    """
    if indices is None:
        indices = num_features  # None selects everything

    if isinstance(indices, int):
        # An int n means "the last n features".
        assert 0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})'
        take_indices = list(range(num_features - indices, num_features))
    else:
        take_indices: List[int] = []
        for raw in indices:
            idx = raw if raw >= 0 else num_features + raw
            assert 0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})'
            take_indices.append(idx)

    max_index = max(take_indices)
    if as_set and not torch.jit.is_scripting():
        return set(take_indices), max_index
    return take_indices, max_index
321
-
322
-
323
class ResNet(nn.Module):
    """ResNet / ResNeXt

    This class implements all variants of ResNet, ResNeXt that
      * have > 1 stride in the 3x3 conv layer of bottleneck
      * have conv-bn-act ordering

    This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
    variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
    'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.

    ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
      * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
      * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
      * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
      * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
      * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
      * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
      * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample

    ResNeXt
      * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
      * same c,d, e, s variants as ResNet can be enabled

    All convolution, pooling and normalization layers built here are the 3D variants.
    """

    def __init__(
            self,
            block: "Type[Union[BasicBlock, Bottleneck]]",
            layers: Tuple[int, ...],
            num_classes: int = 1000,
            in_chans: int = 1,
            output_stride: int = 32,
            global_pool: str = 'avg',
            cardinality: int = 1,
            base_width: int = 64,
            stem_width: int = 64,
            stem_type: str = '',
            replace_stem_pool: bool = False,
            block_reduce_first: int = 1,
            down_kernel_size: int = 1,
            avg_down: bool = False,
            channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Type[nn.Module] = nn.BatchNorm3d,
            drop_rate: float = 0.0,
            zero_init_last: bool = True,
            block_args: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
            layers (List[int]) : number of layers in each block
            num_classes (int): number of classification classes (default 1000)
            in_chans (int): number of input channels. (default 1)
            output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
            global_pool (str): Global pooling type. One of 'avg', 'max' (default 'avg')
            cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
            base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
            stem_width (int): number of channels in stem convolutions (default 64)
            stem_type (str): The type of stem (default ''):
                * '', default - a single 7x7 conv with a width of 64 (stem_width is not used here)
                * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
                * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
            replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
            block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
                1 for all archs except senets, where 2 (default 1)
            down_kernel_size (int): kernel size of residual block downsample path,
                1x1 for most, 3x3 for senets (default: 1)
            avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
            channels (Tuple[int, ...]): base width of each stage (default (64, 128, 256, 512))
            act_layer (str, nn.Module): activation layer
            norm_layer (str, nn.Module): normalization layer
            drop_rate (float): Dropout probability before classifier, for training (default 0.)
            zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
            block_args (dict): Extra kwargs to pass through to block module
        """
        super(ResNet, self).__init__()
        block_args = block_args or dict()
        assert output_stride in (8, 16, 32)
        self.num_classes = num_classes
        self.drop_rate = drop_rate

        # Stem
        deep_stem = 'deep' in stem_type
        inplanes = stem_width * 2 if deep_stem else 64
        if deep_stem:
            stem_chs = (stem_width, stem_width)
            if 'tiered' in stem_type:
                stem_chs = (3 * (stem_width // 4), stem_width)
            self.conv1 = nn.Sequential(*[
                nn.Conv3d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
                norm_layer(stem_chs[0]),
                act_layer(),
                nn.Conv3d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
                norm_layer(stem_chs[1]),
                act_layer(),
                nn.Conv3d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
        else:
            self.conv1 = nn.Conv3d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(inplanes)
        self.act1 = act_layer()
        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]

        # Stem pooling. The name 'maxpool' remains for weight compatibility.
        if replace_stem_pool:
            self.maxpool = nn.Sequential(*filter(None, [
                nn.Conv3d(inplanes, inplanes, 3, stride=2, padding=1, bias=False),
                norm_layer(inplanes),
                act_layer(),
            ]))
        else:
            self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        # Feature Blocks
        block_fns = to_ntuple(len(channels))(block)
        stage_modules, stage_feature_info = make_blocks(
            block_fns,
            channels,
            layers,
            inplanes,
            cardinality=cardinality,
            base_width=base_width,
            output_stride=output_stride,
            reduce_first=block_reduce_first,
            avg_down=avg_down,
            down_kernel_size=down_kernel_size,
            act_layer=act_layer,
            norm_layer=norm_layer,
            **block_args,
        )
        for stage in stage_modules:
            self.add_module(*stage)  # layer1, layer2, etc
        self.feature_info.extend(stage_feature_info)

        # Head (Pooling and Classifier)
        self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
        self.global_pool = self._create_global_pool(global_pool)
        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(zero_init_last=zero_init_last)

    @staticmethod
    def _create_global_pool(global_pool: str, allow_empty: bool = False) -> nn.Module:
        """Return the pooling module for ``global_pool`` ('avg' or 'max').

        With ``allow_empty`` an empty string yields ``nn.Identity`` (no
        pooling). Any other value raises ``NotImplementedError``.
        """
        if global_pool == 'avg':
            return nn.AdaptiveAvgPool3d((1, 1, 1))
        if global_pool == 'max':
            return nn.AdaptiveMaxPool3d((1, 1, 1))
        if allow_empty and global_pool == '':
            return nn.Identity()
        raise NotImplementedError('Global pooling type not supported: {}'.format(global_pool))

    @torch.jit.ignore
    def init_weights(self, zero_init_last: bool = True):
        """Kaiming-normal init for every Conv3d; optionally zero the last norm of each block."""
        for n, m in self.named_modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if zero_init_last:
            for m in self.modules():
                if hasattr(m, 'zero_init_last'):
                    m.zero_init_last()

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False):
        """Return regex patterns that group parameter names into stem and blocks."""
        matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
        return matcher

    @torch.jit.ignore
    def get_classifier(self, name_only: bool = False):
        """Return the classifier module, or its attribute name if ``name_only``."""
        return 'fc' if name_only else self.fc

    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
        """Replace the pooling layer and the classifier.

        Args:
            num_classes: New number of classes; 0 or less installs ``nn.Identity`` as classifier.
            global_pool: 'avg', 'max', or '' for no pooling (``nn.Identity``). With '', ``forward_head``
                flattens the whole unpooled feature map.

        Raises:
            NotImplementedError: For any other ``global_pool`` value.
        """
        self.num_classes = num_classes
        # '' must be accepted here: prune_intermediate_layers() calls reset_classifier(0, '').
        self.global_pool = self._create_global_pool(global_pool, allow_empty=True)

        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHWD',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Feature 0 is the stem output (after act1, before pooling); features 1-4 are the outputs of
        layer1-layer4.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Accepted for API compatibility; not used by this implementation
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs; only 'NCHWD' is accepted
            intermediates_only: Only return intermediate features
        Returns:
            The list of intermediates if ``intermediates_only``, otherwise a tuple of the last computed
            feature map and the list of intermediates.
        """
        assert output_fmt in ('NCHWD',), 'Output shape must be NCHWD.'
        intermediates = []
        take_indices, max_index = feature_take_indices(5, indices)

        # forward pass
        feat_idx = 0
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        if feat_idx in take_indices:
            intermediates.append(x)
        x = self.maxpool(x)

        layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
        if stop_early:
            layer_names = layer_names[:max_index]
        for n in layer_names:
            feat_idx += 1
            x = getattr(self, n)(x)  # won't work with torchscript, but keeps code reasonable, FML
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Stages after the last requested feature are replaced by ``nn.Identity``. With ``prune_head`` the
        pooling layer and classifier are replaced by ``nn.Identity`` as well. ``prune_norm`` is accepted
        but not used. Returns the resolved feature indices.
        """
        take_indices, max_index = feature_take_indices(5, indices)
        layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
        layer_names = layer_names[max_index:]
        for n in layer_names:
            setattr(self, n, nn.Identity())
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem and the four residual stages; return the final feature map."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Pool, flatten, apply dropout, and classify unless ``pre_logits`` is set."""
        x = self.global_pool(x)
        x = x.flatten(1)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        return x if pre_logits else self.fc(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``forward_features`` followed by ``forward_head``."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    @classmethod
    def from_pretrained(
            cls,
            checkpoint_path_or_url: Union[str, os.PathLike],
            verbose: bool = True,
            **kwargs
    ) -> 'ResNet':
        """Load pretrained model weights from a local path or a URL.

        Args:
            checkpoint_path_or_url: Local file path, or an http(s) URL.
            verbose: Print progress messages and the ``load_state_dict`` result.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A model with the checkpoint loaded non-strictly (``strict=False``).

        Raises:
            FileNotFoundError: If a local checkpoint path does not exist.
        """
        model = cls(**kwargs)

        def _is_url(path: str) -> bool:
            try:
                parsed = urlparse(str(path))
                return parsed.scheme in ('http', 'https')
            except Exception:
                return False

        # NOTE(security): weights_only=False lets torch unpickle arbitrary objects, which can execute
        # code. Only load checkpoints from trusted sources, or switch to weights_only=True if the
        # checkpoints contain nothing but tensors.
        if _is_url(checkpoint_path_or_url):
            if verbose:
                print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
            state_dict = torch.hub.load_state_dict_from_url(
                checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
        else:
            local_path = os.fspath(checkpoint_path_or_url)
            if not os.path.exists(local_path):
                raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
            if verbose:
                print(f"Loading checkpoint from local path: {local_path}")
            state_dict = torch.load(local_path, map_location='cpu', weights_only=False)

        msg = model.load_state_dict(state_dict, strict=False)
        if verbose:
            print(f"Loaded pretrained weights with msg: {msg}")
        return model
622
-
623
-
624
-
625
def resnet18(
        checkpoint_path_or_url: Optional[str] = None,
        **kwargs
) -> ResNet:
    """Build a 3D ResNet-18 (BasicBlock, stage depths 2-2-2-2).

    When ``checkpoint_path_or_url`` is truthy the model is created through
    ``ResNet.from_pretrained``; otherwise it is constructed directly. Extra
    keyword arguments are forwarded to ``ResNet``.
    """
    model_kwargs = dict(
        block=BasicBlock,
        layers=[2, 2, 2, 2],
        cardinality=1,
        **kwargs,
    )
    if not checkpoint_path_or_url:
        return ResNet(**model_kwargs)
    return ResNet.from_pretrained(checkpoint_path_or_url, **model_kwargs)
640
-
641
-
642
def resnet34(
        checkpoint_path_or_url: Optional[str] = None,
        **kwargs
) -> ResNet:
    """Build a 3D ResNet-34 (BasicBlock, stage depths 3-4-6-3).

    When ``checkpoint_path_or_url`` is truthy the model is created through
    ``ResNet.from_pretrained``; otherwise it is constructed directly. Extra
    keyword arguments are forwarded to ``ResNet``.
    """
    model_kwargs = dict(
        block=BasicBlock,
        layers=[3, 4, 6, 3],
        cardinality=1,
        **kwargs,
    )
    if not checkpoint_path_or_url:
        return ResNet(**model_kwargs)
    return ResNet.from_pretrained(checkpoint_path_or_url, **model_kwargs)
657
-
658
-
659
def resnet50(
        checkpoint_path_or_url: Optional[str] = None,
        **kwargs
) -> ResNet:
    """Build a 3D ResNet-50 (Bottleneck, stage depths 3-4-6-3).

    When ``checkpoint_path_or_url`` is truthy the model is created through
    ``ResNet.from_pretrained``; otherwise it is constructed directly. Extra
    keyword arguments are forwarded to ``ResNet``.
    """
    model_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        cardinality=1,
        **kwargs,
    )
    if not checkpoint_path_or_url:
        return ResNet(**model_kwargs)
    return ResNet.from_pretrained(checkpoint_path_or_url, **model_kwargs)
674
-
675
-
676
def resnet101(
        checkpoint_path_or_url: Optional[str] = None,
        **kwargs
) -> ResNet:
    """Build a 3D ResNet-101 (Bottleneck, stage depths 3-4-23-3).

    When ``checkpoint_path_or_url`` is truthy the model is created through
    ``ResNet.from_pretrained``; otherwise it is constructed directly. Extra
    keyword arguments are forwarded to ``ResNet``.
    """
    model_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        cardinality=1,
        **kwargs,
    )
    if not checkpoint_path_or_url:
        return ResNet(**model_kwargs)
    return ResNet.from_pretrained(checkpoint_path_or_url, **model_kwargs)
691
-
692
-
693
def resnext50(
        checkpoint_path_or_url: Optional[str] = None,
        **kwargs
) -> ResNet:
    """Build a 3D ResNeXt-50 (Bottleneck, depths 3-4-6-3, cardinality 32, base width 4).

    When ``checkpoint_path_or_url`` is truthy the model is created through
    ``ResNet.from_pretrained``; otherwise it is constructed directly. Extra
    keyword arguments are forwarded to ``ResNet``.
    """
    model_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 6, 3],
        cardinality=32,
        base_width=4,
        **kwargs,
    )
    if not checkpoint_path_or_url:
        return ResNet(**model_kwargs)
    return ResNet.from_pretrained(checkpoint_path_or_url, **model_kwargs)
709
-
710
-
711
def resnext101(
    checkpoint_path_or_url: Optional[str] = None,
    **kwargs
) -> ResNet:
    """Build a ResNeXt-101 with 3D operations.

    The architecture uses ``Bottleneck`` with stage depths ``[3, 4, 23, 3]``,
    ``cardinality=32`` and ``base_width=8``.

    Args:
        checkpoint_path_or_url: Optional checkpoint location. When truthy the
            model is created via ``ResNet.from_pretrained``; otherwise
            ``ResNet`` is instantiated directly.
        **kwargs: Additional keyword arguments forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet``.
    """
    model_kwargs = dict(
        block=Bottleneck,
        layers=[3, 4, 23, 3],
        cardinality=32,
        base_width=8,
        **kwargs,
    )
    if not checkpoint_path_or_url:
        return ResNet(**model_kwargs)
    return ResNet.from_pretrained(checkpoint_path_or_url, **model_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/models/seomt.py DELETED
@@ -1,394 +0,0 @@
1
- # Adapted from https://github.com/tue-mps/eomt/
2
-
3
- import math
4
- from typing import Optional, Tuple, Union
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- from spectre.models import VisionTransformer
11
- from spectre.models.layers import LayerNorm3d
12
-
13
-
14
class ScaleBlock(nn.Module):
    """Upscaling stage: strided conv -> GELU -> depthwise 3x3x3 conv -> LayerNorm3d.

    Args:
        embed_dim: Number of input and output channels of every layer.
        scale_factors: Used as both kernel size and stride of the first
            convolution.
        conv1_layer: Layer class used for the first convolution
            (``nn.ConvTranspose3d`` by default).
    """

    def __init__(
        self,
        embed_dim: int,
        scale_factors: Union[int, Tuple[int, int, int]] = (2, 2, 2),
        conv1_layer: nn.Module = nn.ConvTranspose3d,
    ):
        super().__init__()

        # Kernel size equals stride for the first convolution.
        self.conv1 = conv1_layer(
            embed_dim, embed_dim, kernel_size=scale_factors, stride=scale_factors
        )
        self.act = nn.GELU()
        # groups == channels: each channel is filtered independently.
        self.conv2 = nn.Conv3d(
            embed_dim, embed_dim, kernel_size=3, padding=1, groups=embed_dim, bias=False
        )
        self.norm = LayerNorm3d(embed_dim)

    def forward(self, x):
        """Apply the four layers in sequence and return the result."""
        return self.norm(self.conv2(self.act(self.conv1(x))))
48
-
49
-
50
def compute_upscale_stages(patch_size, min_size=4):
    """Return, per dimension, how many 2x upscales are needed.

    For each entry ``size`` of ``patch_size`` the result is
    ``max(0, floor(log2(size)) - floor(log2(min_size)))``.

    Args:
        patch_size: Iterable of positive patch extents, one per dimension.
        min_size: Extent below which no further upscaling is counted.

    Returns:
        A list with one stage count per entry of ``patch_size``.
    """
    return [
        max(0, int(math.log2(extent)) - int(math.log2(min_size)))
        for extent in patch_size
    ]
57
-
58
def voxel_shuffle_3d(x: torch.Tensor, r: int = 2) -> torch.Tensor:
    """3D analogue of pixel shuffle.

    Rearranges a 5D tensor of shape (N, C*r^3, D, H, W) into
    (N, C, D*r, H*r, W*r) by moving groups of r^3 channels into the
    spatial dimensions.
    """
    batch, channels, depth, height, width = x.size()
    block = r ** 3
    assert channels % block == 0, f"Channels {channels} not divisible by r^3={r**3}"
    out_channels = channels // block

    # (N, C, r, r, r, D, H, W): split the channel axis into the output
    # channel and the three per-axis sub-voxel offsets.
    split = x.view(batch, out_channels, r, r, r, depth, height, width)
    # (N, C, D, r, H, r, W, r): put each offset next to its spatial axis.
    interleaved = split.permute(0, 1, 5, 2, 6, 3, 7, 4)
    # Merge every (axis, offset) pair into one upscaled axis.
    return interleaved.reshape(batch, out_channels, depth * r, height * r, width * r)
70
-
71
-
72
class MLPUpBlock3D(nn.Module):
    """2x upsampling block built from a per-voxel projection.

    Steps:
        1. 1x1x1 convolution expanding the channels by 2^3.
        2. 3D voxel shuffle that doubles D, H and W.
        3. Optional normalisation and activation (identity when ``None``).
    """

    def __init__(self, channels: int, norm=nn.Identity, activation=nn.GELU):
        super().__init__()
        self.proj = nn.Conv3d(channels, channels * 8, kernel_size=1, bias=True)
        self.norm = nn.Identity() if norm is None else norm(channels)
        self.act = nn.Identity() if activation is None else activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.proj(x)                    # (N, C*8, D, H, W)
        upsampled = voxel_shuffle_3d(expanded, 2)  # (N, C, 2D, 2H, 2W)
        return self.act(self.norm(upsampled))
91
-
92
-
93
class SmolMLPDecoder3D(nn.Module):
    """Small decoder: two 2x MLP upsampling blocks (4x total) plus a 1x1x1 head.

    Args:
        in_channels: Channel count kept through both upsampling blocks.
        out_channels: Channel count produced by the final 1x1x1 convolution.
        norm: Normalisation class handed to each ``MLPUpBlock3D``.
        activation: Activation class handed to each ``MLPUpBlock3D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm=LayerNorm3d,
        activation=nn.GELU,
    ):
        super().__init__()
        block_options = dict(norm=norm, activation=activation)
        self.up1 = MLPUpBlock3D(in_channels, **block_options)
        self.up2 = MLPUpBlock3D(in_channels, **block_options)
        self.head = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # up1 and up2 each double the spatial size; head maps to out_channels.
        return self.head(self.up2(self.up1(x)))
114
-
115
class SpatialMLPDecoder3D(nn.Module):
    """Per-voxel MLP upsampler for logits.

    Expands (B, Q, H, W, D) to (B, Q, H*s, W*s, D*s): for every input voxel
    the MLP predicts an s^3 block per channel, and the blocks are then laid
    out spatially.

    Args:
        num_classes: Channel count Q of the input and output.
        upscale_factor: Spatial factor ``s`` applied to every axis.
        hidden_mul: Hidden width of the MLP as a multiple of ``num_classes``.
    """

    def __init__(self, num_classes: int, upscale_factor: int = 4, hidden_mul: int = 4):
        super().__init__()
        self.num_classes = num_classes
        self.s = upscale_factor
        hidden_width = hidden_mul * num_classes
        block_size = num_classes * (upscale_factor ** 3)
        self.mlp = nn.Sequential(
            nn.Linear(num_classes, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, block_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, Q, H, W, D)
        B, Q, H, W, D = x.shape
        s = self.s
        # Channels last, one row per voxel: (B*H*W*D, Q).
        voxels = x.permute(0, 2, 3, 4, 1).reshape(B * H * W * D, Q)
        # (B*H*W*D, Q*s^3) -> (B, H, W, D, Q, s, s, s)
        blocks = self.mlp(voxels).view(B, H, W, D, Q, s, s, s)
        # (B, Q, H, s, W, s, D, s): pair each axis with its sub-voxel offset.
        interleaved = blocks.permute(0, 4, 1, 5, 2, 6, 3, 7)
        return interleaved.reshape(B, Q, H * s, W * s, D * s)
141
-
142
-
143
class SimpleConvDecoder3D(nn.Module):
    """Depthwise ``ConvTranspose3d`` upsampler.

    Input and output both have ``num_classes`` channels, and
    ``groups=num_classes`` makes the transposed convolution operate on each
    class channel separately. Kernel size and stride both equal
    ``upscale_factor``.
    """

    def __init__(self, num_classes: int, upscale_factor: int = 4):
        super().__init__()
        self.deconv = nn.ConvTranspose3d(
            num_classes,
            num_classes,
            kernel_size=upscale_factor,
            stride=upscale_factor,
            padding=0,
            output_padding=0,
            groups=num_classes,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, Q, H, W, D)
        upsampled = self.deconv(x)
        return upsampled
166
-
167
-
168
class SEoMT(nn.Module):
    """Encoder-only mask transformer for segmentation on a ViT backbone.

    One learnable query per class is prepended to the token sequence before
    the last ``num_blocks`` backbone blocks. Mask logits are the dot product
    between the ``mask_head`` output of every query and the upscaled patch
    features. When ``masked_attn_enabled`` is set, a mask is predicted before
    each of those blocks and used to restrict query-to-patch attention.

    Args:
        backbone: Vision transformer providing ``patch_embed``, ``blocks``,
            ``norm``, ``embed_dim`` and ``num_prefix_tokens``.
        num_classes: Number of classes; also the number of queries.
        num_blocks: Number of trailing backbone blocks that see the queries.
        masked_attn_enabled: Whether intermediate masks are predicted and
            used as attention masks.
        return_only_final_layer: Return only the prediction made after the
            last block instead of the list of all predictions.
        upscale_output: With ``for_nnunet``, resize intermediate predictions
            to ``input_size / 2**stage``.
        for_nnunet: Move the first spatial axis of the input last before the
            backbone, undo that on every output, and return the outputs in
            reversed order (final prediction first).
        decoder: Accepted for backward compatibility; not used.
    """

    def __init__(
        self,
        backbone: VisionTransformer,
        num_classes: int,
        num_blocks=4,
        masked_attn_enabled=True,
        return_only_final_layer=False,
        upscale_output=True,
        for_nnunet=False,
        decoder=False,
    ):
        super().__init__()
        self.backbone = backbone
        self.num_q = num_classes  # one query per class
        self.num_blocks = num_blocks
        self.masked_attn_enabled = masked_attn_enabled
        self.return_only_final_layer = return_only_final_layer
        self.upscale_output = upscale_output
        # Per-block probability that a query keeps its predicted attention mask.
        self.register_buffer("attn_mask_probs", torch.ones(num_blocks))
        self.for_nnunet = for_nnunet
        self.q = nn.Embedding(num_classes, self.backbone.embed_dim)

        embed_dim = self.backbone.embed_dim
        self.mask_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

        patch_size = self.backbone.patch_embed.patch_size
        num_upscale_stages = compute_upscale_stages(patch_size, min_size=4)

        # A dimension is upscaled by 2 at a stage only while it still has
        # upscales remaining; otherwise its factor is 1.
        upscale_blocks = []
        for stage_idx in range(max(num_upscale_stages)):
            scale_factors = tuple(
                2 if stage_idx < num_upscale_stages[dim] else 1
                for dim in range(len(patch_size))
            )
            upscale_blocks.append(
                ScaleBlock(self.backbone.embed_dim, scale_factors=scale_factors)
            )

        self.upscale = nn.Sequential(*upscale_blocks)

    def _output_size(self, stage: int) -> Tuple[int, ...]:
        """Return ``patch_size * grid_size / 2**stage`` per spatial dimension."""
        patch_embed = self.backbone.patch_embed
        return tuple(
            int(patch_embed.patch_size[dim] * patch_embed.grid_size[dim] / 2 ** stage)
            for dim in range(len(patch_embed.patch_size))
        )

    @staticmethod
    def _to_nnunet_layout(logits: torch.Tensor) -> torch.Tensor:
        """Move the last spatial axis back to the front of the spatial axes."""
        return logits.permute(0, 1, 4, 2, 3).contiguous()

    def _predict(self, x: torch.Tensor, stage: Optional[int] = None):
        """Compute mask logits from a token sequence that starts with the queries.

        ``stage`` is accepted for backward compatibility and is not used.
        """
        q = x[:, : self.num_q, :]
        # Drop the queries and prefix tokens, keep the patch tokens only.
        x = x[:, self.num_q + self.backbone.num_prefix_tokens :, :]
        x = x.transpose(1, 2).reshape(
            x.shape[0], -1, *self.backbone.patch_embed.grid_size
        )
        mask_logits = torch.einsum(
            "bqc, bchwd -> bqhwd", self.mask_head(q), self.upscale(x)
        )
        return mask_logits

    @torch.compiler.disable
    def _disable_attn_mask(self, attn_mask, prob):
        """Randomly lift the query-to-patch mask for individual queries.

        Each query keeps its mask with probability ``prob``; otherwise all of
        its query-to-patch entries are set to ``True``.
        """
        if prob < 1:
            random_queries = (
                torch.rand(attn_mask.shape[0], self.num_q, device=attn_mask.device)
                > prob
            )
            attn_mask[
                :, : self.num_q, self.num_q + self.backbone.num_prefix_tokens :
            ][random_queries] = True

        return attn_mask

    def _attn(self, module: 'Attention', x: torch.Tensor, mask: Optional[torch.Tensor], rope = None):
        """Run the attention ``module`` on ``x`` with an optional boolean mask.

        ``mask`` has shape (B, N, N) with ``True`` meaning "may attend"; it is
        broadcast over the heads.
        """
        B, N, C = x.shape

        q = module.q(x).reshape(B, N, module.num_heads, module.head_dim).permute(0, 2, 1, 3)
        kv = module.kv(x).reshape(B, N, 2, module.num_heads, module.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = module.q_norm(q), module.k_norm(k)

        if mask is not None:
            mask = mask[:, None, ...].expand(-1, module.num_heads, -1, -1)

        dropout_p = module.attn_drop.p if self.training else 0.0

        if rope is not None:
            if isinstance(rope, list):
                # Stack per-sample (first, second) pairs into batched tensors.
                rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2))
            q, k = module.apply_rotary_pos_emb(q, k, rope)

        if module.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v, mask, dropout_p)
        else:
            attn = (q @ k.transpose(-2, -1)) * module.scale
            if mask is not None:
                attn = attn.masked_fill(~mask, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = module.attn_drop(attn)
            x = attn @ v

        x = module.proj_drop(module.proj(x.transpose(1, 2).reshape(B, N, C)))

        return x

    def forward(self, x: torch.Tensor):
        """Predict mask logits for ``x``.

        Returns:
            The prediction made after the last block when
            ``return_only_final_layer`` is set. Otherwise the list of all
            predictions: in computation order, or reversed (final first) when
            ``for_nnunet`` is set.
        """
        if self.for_nnunet:  # swap data order, will be incoming at czyx - cxyz
            x = x.permute(0, 1, 3, 4, 2).contiguous()

        self.backbone.patch_embed.set_input_size(x.shape[2:])
        x = self.backbone.patch_embed(x)
        x, rope = self.backbone._pos_embed(x)
        x = self.backbone.patch_drop(x)
        x = self.backbone.norm_pre(x)
        attn_mask = None
        mask_logits_per_layer = []

        num_backbone_blocks = len(self.backbone.blocks)
        first_query_block = num_backbone_blocks - self.num_blocks

        for i, block in enumerate(self.backbone.blocks):
            if i == first_query_block:
                # Prepend the learnable queries to the token sequence.
                x = torch.cat(
                    (self.q.weight[None, :, :].expand(x.shape[0], -1, -1), x), dim=1
                )

            if self.masked_attn_enabled and i >= first_query_block:
                mask_logits = self._predict(self.backbone.norm(x))

                if not self.for_nnunet:
                    mask_logits_per_layer.append(mask_logits)
                elif self.upscale_output:
                    # Earlier blocks are supervised at coarser resolutions.
                    stage = num_backbone_blocks - i
                    resized = F.interpolate(
                        mask_logits, self._output_size(stage), mode="trilinear"
                    )
                    mask_logits_per_layer.append(self._to_nnunet_layout(resized))
                else:
                    mask_logits_per_layer.append(self._to_nnunet_layout(mask_logits))

                # Start from "everything may attend" and restrict only the
                # query-to-patch entries using the predicted masks.
                attn_mask = torch.ones(
                    x.shape[0],
                    x.shape[1],
                    x.shape[1],
                    dtype=torch.bool,
                    device=x.device,
                )
                interpolated = F.interpolate(
                    mask_logits,
                    self.backbone.patch_embed.grid_size,
                    mode="trilinear",
                )
                interpolated = interpolated.view(
                    interpolated.size(0), interpolated.size(1), -1
                )
                attn_mask[
                    :,
                    : self.num_q,
                    self.num_q + self.backbone.num_prefix_tokens :,
                ] = (
                    interpolated > 0
                )
                attn_mask = self._disable_attn_mask(
                    attn_mask,
                    self.attn_mask_probs[i - first_query_block],
                )

            x = x + block.drop_path1(
                block.ls1(self._attn(block.attn, block.norm1(x), attn_mask, rope=rope))
            )
            x = x + block.drop_path2(block.ls2(block.mlp(block.norm2(x))))

        mask_logits = self._predict(self.backbone.norm(x))
        if self.for_nnunet:
            resized = F.interpolate(mask_logits, self._output_size(0), mode="trilinear")
            final_logits = self._to_nnunet_layout(resized)
        else:
            final_logits = mask_logits
        mask_logits_per_layer.append(final_logits)

        if self.return_only_final_layer:
            # Always the last prediction. Indexing the list at 0 would only be
            # correct after the nnU-Net reversal below.
            return final_logits

        if self.for_nnunet:
            # return in reversed order for deep supervision
            mask_logits_per_layer = mask_logits_per_layer[::-1]
        return mask_logits_per_layer
370
-
371
-
372
if __name__ == "__main__":
    # Smoke test: build an SEoMT on a large ViT backbone with rotary position
    # embeddings, report its size and print the output shape(s).
    from spectre.models import vit_large_patch16_128

    demo_backbone = vit_large_patch16_128(
        pos_embed='rope',
        rope_kwargs={"base": 1000.0},  # works for most 3D models
    )
    model = SEoMT(
        backbone=demo_backbone,
        num_classes=4,
        num_blocks=4,
        masked_attn_enabled=True,
        return_only_final_layer=True,
        for_nnunet=True,
        upscale_output=True,
        decoder=False,
    )

    # print number of parameters
    trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters: {trainable_count}")

    dummy_input = torch.randn(2, 1, 64, 128, 128)
    out = model(dummy_input)
    for o in out:
        print(o.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/models/upsample_anything.py DELETED
@@ -1,319 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from torch.optim.lr_scheduler import LambdaLR
6
-
7
-
8
def UPA(hr_image, lr_volume, device="cuda", use_amp=True):
    """Upsample ``lr_volume`` to the resolution of ``hr_image``.

    A ``LearnablePixelwiseAnisoJBU3D`` is fitted for 350 Adam steps to
    reconstruct ``hr_image`` from a trilinearly downsampled copy of itself
    (L1 loss), and the fitted model is then applied to ``lr_volume`` with
    ``hr_image`` as guide.

    Args:
        hr_image: numpy or torch [C,Hh,Wh,Dh]
        lr_volume: torch [1,C,Hl,Wl,Dl]
        device: Device to run on; a string such as ``"cuda"`` or ``"cuda:1"``
            or a ``torch.device``.
        use_amp: Whether to enable autocast and gradient scaling.

    Returns:
        The upsampled volume produced by the fitted model.

    Raises:
        AssertionError: If the three spatial scale factors differ.
    """
    # autocast and GradScaler take a device *type* ("cuda", "cpu"), not an
    # indexed device such as "cuda:1".
    device_type = torch.device(device).type

    hr = torch.as_tensor(hr_image).unsqueeze(0).float().to(device)

    _, _, Hh, Wh, Dh = hr.shape
    _, _, Hl, Wl, Dl = lr_volume.shape
    scale = Hh // Hl
    assert Wh // Wl == scale and Dh // Dl == scale, "Inconsistent scale factors"

    lr_volume = lr_volume.to(device).float()
    # Low-resolution copy of the guide, used as the training input.
    lr = F.interpolate(hr, scale_factor=1/scale, mode="trilinear", align_corners=False)

    model = LearnablePixelwiseAnisoJBU3D(
        Hl, Wl, Dl, scale=scale
    ).to(device)

    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=1e-1)
    max_steps = 350
    # Exponential decay from 1e-1 to 1e-9 over max_steps.
    gamma = (1e-9 / 1e-1) ** (1.0 / max_steps)
    scheduler = LambdaLR(opt, lr_lambda=lambda step: gamma ** step)
    scaler = torch.amp.GradScaler(device=device_type, enabled=use_amp)

    for step in range(max_steps):
        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device_type, enabled=use_amp):
            pred = model(lr, hr)
            loss = F.l1_loss(pred, hr)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        scheduler.step()

        if step % 50 == 0:
            print(f"step {step}: loss={loss.item():.5f}")

    model.eval()
    with torch.inference_mode(), \
            torch.amp.autocast(device_type=device_type, enabled=use_amp, dtype=torch.float16):
        out = model(lr_volume, hr)

    return out
55
-
56
-
57
- @torch.no_grad()
58
- def _build_offsets_3d(R_max: int, device):
59
- offs = torch.arange(-R_max, R_max + 1, device=device)
60
- dX, dY, dZ = torch.meshgrid(offs, offs, offs, indexing="ij")
61
- return (
62
- dX.reshape(-1),
63
- dY.reshape(-1),
64
- dZ.reshape(-1),
65
- ) # [K]
66
-
67
-
68
def gather_lr_scalar_3d(map_lr, Ui, Vi, Wi):
    """Look up values of a low-resolution scalar map at integer voxel indices.

    Args:
        map_lr: [1,1,Hl,Wl,Dl] or [Hl,Wl,Dl]; may be non-contiguous.
        Ui, Vi, Wi: Integer index tensors of identical shape
            (e.g. [Bn,Hh,Wh,Dh]) addressing the three trailing axes of
            ``map_lr``.

    Returns:
        Tensor with the shape of ``Ui`` holding ``map_lr[..., Ui, Vi, Wi]``.

    Raises:
        RuntimeError: If ``map_lr`` holds more than ``Hl * Wl * Dl`` elements.
    """
    Hl, Wl, Dl = map_lr.shape[-3:]
    # Row-major linear index into the flattened (Hl, Wl, Dl) volume.
    flat_index = (Ui * Wl * Dl + Vi * Dl + Wi).reshape(-1)
    # reshape (not view) so permuted or sliced maps are accepted; the explicit
    # size still rejects maps with extra non-singleton leading dimensions.
    flat_map = map_lr.reshape(Hl * Wl * Dl)
    values = flat_map.index_select(0, flat_index)
    return values.view(Ui.shape)
79
-
80
-
81
def gs_jbu_aniso_noparent_3d(
    feat_lr,  # [1,C,Hl,Wl,Dl]
    guide_hr,  # [1,G,Hh,Wh,Dh]
    scale,
    sigma_x_map,
    sigma_y_map,
    sigma_z_map,
    sigma_r_map,
    R_max=3,
    alpha_dyn=2.0,
    C_chunk=64,
    Nn_chunk=125,
):
    """Guided upsampling of ``feat_lr`` with per-voxel anisotropic Gaussians.

    Every high-resolution voxel is a normalised weighted sum of
    low-resolution features taken from a ``(2 * R_max + 1) ** 3``
    neighbourhood around its nearest low-resolution voxel. A weight is the
    product of an axis-wise spatial Gaussian (``sigma_x/y/z``) and a range
    Gaussian on the guide difference (``sigma_r``), with all sigmas read from
    the low-resolution sigma maps at the neighbour's position. Neighbours are
    processed in chunks of ``Nn_chunk`` using a running maximum of the
    log-weights for numerical stability.

    Args:
        feat_lr: Low-resolution features [1,C,Hl,Wl,Dl].
        guide_hr: High-resolution guide [1,G,Hh,Wh,Dh].
        scale: High-res / low-res size ratio used for coordinate mapping.
        sigma_x_map, sigma_y_map, sigma_z_map: Spatial sigma maps; passed to
            ``F.interpolate`` and ``gather_lr_scalar_3d``, so presumably
            [1,1,Hl,Wl,Dl] -- confirm against callers.
        sigma_r_map: Range sigma map, same layout.
        R_max: Maximum neighbourhood radius in low-resolution voxels.
        alpha_dyn: Per-voxel radius is ``ceil(alpha_dyn * max sigma)``
            clamped to ``[1, R_max]``.
        C_chunk: Number of feature channels gathered at once.
        Nn_chunk: Number of neighbour offsets processed at once.

    Returns:
        Upsampled features [1,C,Hh,Wh,Dh] in the dtype of ``feat_lr``.
    """
    _, C, Hl, Wl, Dl = feat_lr.shape
    _, _, Hh, Wh, Dh = guide_hr.shape
    device = feat_lr.device
    dtype_feat = feat_lr.dtype

    # HR grid
    x = torch.arange(Hh, device=device, dtype=torch.float32)
    y = torch.arange(Wh, device=device, dtype=torch.float32)
    z = torch.arange(Dh, device=device, dtype=torch.float32)
    X, Y, Z = torch.meshgrid(x, y, z, indexing="ij")

    # Continuous LR coordinates of every HR voxel centre.
    u_lr = (X + 0.5) / scale - 0.5
    v_lr = (Y + 0.5) / scale - 0.5
    w_lr = (Z + 0.5) / scale - 0.5

    # Nearest LR voxel: the centre of the neighbourhood.
    uc = torch.round(u_lr).clamp(0, Hl - 1).long()
    vc = torch.round(v_lr).clamp(0, Wl - 1).long()
    wc = torch.round(w_lr).clamp(0, Dl - 1).long()

    # Dynamic radius
    sigma_eff = torch.maximum(
        sigma_x_map,
        torch.maximum(sigma_y_map, sigma_z_map),
    )
    sigma_eff_hr = F.interpolate(
        sigma_eff, (Hh, Wh, Dh), mode="trilinear", align_corners=False
    )
    R_map = torch.ceil(alpha_dyn * sigma_eff_hr).clamp(1, R_max).long()

    dX_all, dY_all, dZ_all = _build_offsets_3d(R_max, device)

    # Running numerator, denominator and log-weight maximum.
    num = torch.zeros(C, Hh, Wh, Dh, device=device, dtype=torch.float32)
    den = torch.zeros(Hh, Wh, Dh, device=device, dtype=torch.float32)
    m = torch.full((Hh, Wh, Dh), -1e9, device=device, dtype=torch.float32)

    feat_flat = feat_lr[0].permute(1, 2, 3, 0).reshape(-1, C)
    guide_lr = F.interpolate(
        guide_hr, (Hl, Wl, Dl), mode="trilinear", align_corners=False
    )

    for n0 in range(0, len(dX_all), Nn_chunk):
        dX = dX_all[n0:n0+Nn_chunk][:, None, None, None]
        dY = dY_all[n0:n0+Nn_chunk][:, None, None, None]
        dZ = dZ_all[n0:n0+Nn_chunk][:, None, None, None]

        # Neighbour indices, clamped to the LR volume.
        Ui = torch.clamp(uc.unsqueeze(0) + dX, 0, Hl - 1)
        Vi = torch.clamp(vc.unsqueeze(0) + dY, 0, Wl - 1)
        Wi = torch.clamp(wc.unsqueeze(0) + dZ, 0, Dl - 1)

        # Keep only offsets inside the per-voxel dynamic radius.
        mask = (dX**2 + dY**2 + dZ**2 <= R_map**2).squeeze(0).squeeze(0)

        # HR coordinates of the neighbour centres.
        cx = (Ui.float() + 0.5) * scale - 0.5
        cy = (Vi.float() + 0.5) * scale - 0.5
        cz = (Wi.float() + 0.5) * scale - 0.5

        dx = X.unsqueeze(0) - cx
        dy = Y.unsqueeze(0) - cy
        dz = Z.unsqueeze(0) - cz

        sx = gather_lr_scalar_3d(sigma_x_map, Ui, Vi, Wi).clamp_min(1e-6)
        sy = gather_lr_scalar_3d(sigma_y_map, Ui, Vi, Wi).clamp_min(1e-6)
        sz = gather_lr_scalar_3d(sigma_z_map, Ui, Vi, Wi).clamp_min(1e-6)
        sr = gather_lr_scalar_3d(sigma_r_map, Ui, Vi, Wi).clamp_min(1e-6)

        log_ws = (
            -(dx**2)/(2*sx**2)
            -(dy**2)/(2*sy**2)
            -(dz**2)/(2*sz**2)
        )

        # Squared guide difference, summed over guide channels.
        diff2 = 0.0
        for g in range(guide_hr.shape[1]):
            g0 = gather_lr_scalar_3d(guide_lr[0, g], Ui, Vi, Wi)
            diff2 += (guide_hr[0, g] - g0) ** 2

        log_wr = -diff2 / (2 * sr**2 + 1e-8)
        log_w = torch.where(mask, log_ws + log_wr, -1e9)

        # Rescale the accumulators to the new running maximum.
        m_chunk = log_w.max(dim=0).values
        m_new = torch.maximum(m, m_chunk)

        scale_old = torch.exp(m - m_new)
        num *= scale_old
        den *= scale_old

        weights = torch.exp(log_w - m_new)
        den += weights.sum(0)

        idx_flat = (Ui * Wl * Dl + Vi * Dl + Wi).reshape(-1)

        for c0 in range(0, C, C_chunk):
            c1 = min(c0 + C_chunk, C)
            # Slice the channels first so only C_chunk channels are gathered
            # per step; gathering all C and slicing afterwards would defeat
            # the purpose of chunking.
            f = feat_flat[:, c0:c1].index_select(0, idx_flat)
            f = f.view(weights.shape + (c1 - c0,))
            num[c0:c1] += (f * weights[..., None]).sum(0).permute(3, 0, 1, 2)

        m = m_new

    out = (num / den.clamp_min(1e-8)).unsqueeze(0)
    return out.to(dtype_feat)
196
-
197
-
198
class LearnablePixelwiseAnisoJBU3D(nn.Module):
    """Learnable per-voxel sigmas for ``gs_jbu_aniso_noparent_3d``.

    Four parameters of shape (1, 1, Hl, Wl, Dl) store the logarithms of the
    x/y/z spatial sigmas and of the range sigma; ``forward`` exponentiates
    them and delegates to ``gs_jbu_aniso_noparent_3d``.

    Args:
        Hl, Wl, Dl: Spatial size of the low-resolution grid.
        scale: Scale value passed through to the upsampling function.
        init_sigma: Initial value of the three spatial sigmas.
        init_sigma_r: Initial value of the range sigma.
        R_max: Passed through as ``R_max``.
        alpha_dyn: Passed through as ``alpha_dyn``.
    """

    def __init__(
        self,
        Hl,
        Wl,
        Dl,
        scale,
        init_sigma=1.5,
        init_sigma_r=0.1,
        R_max=3,
        alpha_dyn=2.0,
    ):
        super().__init__()
        self.scale = scale
        self.R_max = R_max
        self.alpha_dyn = alpha_dyn

        map_shape = (1, 1, Hl, Wl, Dl)
        log_spatial = math.log(init_sigma)
        log_range = math.log(init_sigma_r)

        # Stored in log space; forward() applies exp().
        self.sx_raw = nn.Parameter(torch.full(map_shape, log_spatial))
        self.sy_raw = nn.Parameter(torch.full(map_shape, log_spatial))
        self.sz_raw = nn.Parameter(torch.full(map_shape, log_spatial))
        self.sr_raw = nn.Parameter(torch.full(map_shape, log_range))

    def forward(self, feat_lr, guide_hr):
        """Upsample ``feat_lr`` guided by ``guide_hr`` with the current sigmas."""
        sigma_x = torch.exp(self.sx_raw)
        sigma_y = torch.exp(self.sy_raw)
        sigma_z = torch.exp(self.sz_raw)
        sigma_r = torch.exp(self.sr_raw)
        return gs_jbu_aniso_noparent_3d(
            feat_lr,
            guide_hr,
            self.scale,
            sigma_x,
            sigma_y,
            sigma_z,
            sigma_r,
            R_max=self.R_max,
            alpha_dyn=self.alpha_dyn,
        )
232
-
233
-
234
if __name__ == "__main__":
    # Demo: crop an image/mask pair, downsample the mask, upsample it again
    # with UPA guided by the image, and write every volume to disk.
    import argparse

    import numpy as np
    import nibabel as nib
    import monai.transforms as transforms

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--mask_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--use_amp", action="store_true")
    args = parser.parse_args()

    def _save_nifti(volume, filename):
        """Write ``volume`` as uint8 NIfTI with an identity affine."""
        nib.save(
            nib.Nifti1Image(volume.cpu().numpy().astype(np.uint8), affine=np.eye(4)),
            filename,
        )

    pair_keys = ("image", "mask")
    transform = transforms.Compose([
        transforms.LoadImaged(keys=pair_keys),
        transforms.EnsureChannelFirstd(keys=pair_keys, channel_dim="no_channel"),
        transforms.ScaleIntensityRanged(
            keys=("image",),
            a_min=-150,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        transforms.Orientationd(keys=pair_keys, axcodes="RAS"),
        transforms.RandWeightedCropd(
            keys=pair_keys,
            w_key="mask",
            spatial_size=(128, 128, 64),
            num_samples=1,
        ),
        transforms.CopyItemsd(keys="mask", times=1, names="mask_low_res"),
        transforms.Resized(
            keys="mask_low_res",
            spatial_size=(16, 16, 8),
            mode="nearest",
            align_corners=False,
        ),
    ])
    inputs = {"image": args.image_path, "mask": args.mask_path}
    sample = transform(inputs)[0]

    # Baseline for comparison: nearest-neighbour upscaling of the LR mask.
    nearest_upscaled = F.interpolate(
        sample["mask_low_res"].unsqueeze(0), size=(128, 128, 64), mode="nearest"
    )
    _save_nifti(nearest_upscaled.squeeze(0).squeeze(0), "mask_low_res_upscaled.nii.gz")

    # One-hot encode the LR mask: (1, X, Y, Z) labels -> (1, 4, X, Y, Z) floats.
    low_res_labels = sample["mask_low_res"].long().squeeze(0)
    one_hot = F.one_hot(low_res_labels, num_classes=4)
    sample["mask_low_res"] = one_hot.permute(3, 0, 1, 2).unsqueeze(0).float()

    print(sample["mask_low_res"].shape)

    mask_out = UPA(
        sample["image"],
        sample["mask_low_res"],
        device=args.device,
        use_amp=args.use_amp,
    )

    mask_out = mask_out.argmax(dim=1, keepdim=True)
    predicted_labels = mask_out.squeeze(0).squeeze(0)

    _save_nifti((sample["image"] * 255).squeeze(0), "image.nii.gz")
    _save_nifti(sample["mask"].squeeze(0), "mask.nii.gz")
    torch.save(predicted_labels.cpu(), "upsampled_mask.pt")
    _save_nifti(predicted_labels, "upsampled_mask.nii.gz")

319
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/checkpointing.py DELETED
@@ -1,238 +0,0 @@
1
- import os
2
- import random
3
- import warnings
4
- from typing import Optional, Any
5
-
6
- import torch
7
- import numpy as np
8
- import torch.distributed as dist
9
-
10
-
11
def _get_local_rng_state() -> dict:
    """Snapshot this process's RNG states as a picklable dict.

    Returns:
        A dict with the keys ``"torch"``, ``"numpy"``, ``"random"`` and
        ``"cuda"``. ``"cuda"`` holds a list with one CPU tensor per visible
        CUDA device, or ``None`` when CUDA is not available.
    """
    cuda_states = None
    if torch.cuda.is_available():
        # keep the per-device states on CPU so the dict can be pickled
        cuda_states = [dev_state.cpu() for dev_state in torch.cuda.get_rng_state_all()]

    return {
        "torch": torch.get_rng_state().cpu(),
        "numpy": np.random.get_state(),
        "random": random.getstate(),
        "cuda": cuda_states,
    }
27
-
28
-
29
def _set_local_rng_state(state: dict) -> None:
    """Restore local RNG states from a dict produced by _get_local_rng_state().

    Args:
        state: dict with optional keys ``"torch"``, ``"cuda"``, ``"numpy"``
            and ``"random"``. ``None`` (or a missing/``None`` entry) is
            skipped silently.

    Notes:
        PyTorch RNG states must be passed as CPU byte tensors, including the
        CUDA ones. Moving the saved CUDA states onto the GPU before calling
        ``torch.cuda.set_rng_state`` makes PyTorch reject them, so they are
        kept on CPU here.
    """
    if state is None:
        return

    torch_state = state.get("torch")
    if torch_state is not None:
        torch.set_rng_state(torch_state.cpu())

    cuda_states = state.get("cuda")
    if cuda_states is not None and torch.cuda.is_available():
        n_devices = torch.cuda.device_count()
        if len(cuda_states) != n_devices:
            # e.g. resuming on a machine with a different number of GPUs
            warnings.warn(
                f"Checkpoint holds {len(cuda_states)} CUDA RNG state(s) but "
                f"{n_devices} device(s) are visible; restoring the overlap only."
            )
        for device_idx, dev_state in enumerate(cuda_states[:n_devices]):
            torch.cuda.set_rng_state(dev_state.cpu(), device=device_idx)

    numpy_state = state.get("numpy")
    if numpy_state is not None:
        np.random.set_state(numpy_state)

    python_state = state.get("random")
    if python_state is not None:
        random.setstate(python_state)
54
-
55
-
56
def save_state(ckpt_path: str, epoch: Optional[int] = None, **named_objects: Any) -> None:
    """
    Save a checkpoint that includes:
    - epoch (optional)
    - state_dicts for provided named_objects
    - rng_states: list of per-rank RNG dictionaries (one entry per world rank)

    If torch.distributed is initialized the RNG states from all ranks are gathered and
    stored in checkpoint["rng_states"] (list indexed by rank). Only rank 0 writes the file.
    In single-process mode the checkpoint contains a single-item rng_states list.

    Args:
        ckpt_path: destination file. Its parent directory is created when the
            path has a directory component.
        epoch: stored under "epoch" when not None.
        **named_objects: objects exposing ``state_dict()``; each is stored
            under its keyword name.
    """
    # A bare filename has an empty dirname, which os.makedirs rejects.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    local_rng = _get_local_rng_state()

    is_distributed = dist.is_available() and dist.is_initialized()
    if is_distributed:
        # gather python objects (picklable) from every rank; this is a
        # collective call, so all ranks must reach it
        rng_states = [None] * dist.get_world_size()
        dist.all_gather_object(rng_states, local_rng)
        is_writer = dist.get_rank() == 0
    else:
        rng_states = [local_rng]
        is_writer = True

    if is_writer:
        checkpoint = {}
        if epoch is not None:
            checkpoint["epoch"] = epoch
        checkpoint["rng_states"] = rng_states

        # in distributed mode these are rank 0's local state_dicts
        for name, obj in named_objects.items():
            checkpoint[name] = obj.state_dict()

        torch.save(checkpoint, ckpt_path)

    if is_distributed:
        # ensure everyone waits until rank 0 finished writing
        dist.barrier()
105
-
106
-
107
def load_state(ckpt_path: str, **named_objects: Any) -> int:
    """
    Load checkpoint saved by save_state.

    - Each process loads the same file and restores its own RNG state (checkpoint['rng_states'][rank]).
    - Named objects that exist in the checkpoint will have their state_dict loaded.
    - Returns epoch (int) if present, otherwise 0.

    Failures are reported through ``warnings.warn`` rather than raised: a
    missing file, a missing or incompatible state_dict, and a missing or
    unrestorable RNG state all leave the corresponding object untouched.

    Args:
        ckpt_path: checkpoint file written by save_state.
        **named_objects: objects exposing ``load_state_dict()``, keyed by the
            name they were saved under.
    """
    if not os.path.isfile(ckpt_path):
        warnings.warn(f"Checkpoint file not found: {ckpt_path}")
        return 0

    # load on all ranks (shared FS assumed)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    epoch = checkpoint.get("epoch", 0)

    # load state_dicts into provided objects
    for name, obj in named_objects.items():
        if name not in checkpoint:
            warnings.warn(f"No state_dict found for '{name}' in checkpoint.")
            continue
        try:
            obj.load_state_dict(checkpoint[name])
        except Exception as e:
            warnings.warn(f"Failed to load state_dict for '{name}': {e}")

    # restore this rank's RNG state
    rng_states = checkpoint.get("rng_states", None)
    if rng_states is None:
        warnings.warn("No 'rng_states' found in checkpoint; RNGs not restored.")
        return epoch

    # single-process runs use the first entry
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank >= len(rng_states):
        # e.g. resuming with more processes than the run that saved the file
        warnings.warn(
            f"Checkpoint has {len(rng_states)} RNG state(s) but this process is "
            f"rank {rank}; RNGs not restored."
        )
        return epoch

    try:
        _set_local_rng_state(rng_states[rank])
    except Exception as e:
        warnings.warn(f"Failed to restore RNG state: {e}")

    return epoch
155
-
156
-
157
def extract_model_from_checkpoint_dinov2(checkpoint_path: str):
    """Split a DINOv2-style training checkpoint into per-component files.

    The model state dict is read from ``checkpoint["model"]`` (or the
    checkpoint itself when that key is absent). Each component's parameters
    are written, with the component prefix stripped, to
    ``<checkpoint without extension>/<component>.pt``. The iBOT heads are
    written only when the teacher iBOT head differs from the teacher DINO head.

    Args:
        checkpoint_path: path to the checkpoint file.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # use on trusted checkpoints.
    checkpoint = torch.load(
        checkpoint_path,
        weights_only=False,
        map_location="cpu",
    )
    model_state = checkpoint.get("model", checkpoint)

    # Strip only the trailing extension. Removing ".pt" anywhere in the path
    # would corrupt names such as "run.pth" or "exp.pt_v2/ckpt.pt".
    path_str = str(checkpoint_path)
    output_dir, extension = os.path.splitext(path_str)
    if not extension:
        # without an extension the directory would collide with the file itself
        output_dir = path_str + "_components"
    os.makedirs(output_dir, exist_ok=True)

    def _strip_prefix(prefix: str) -> dict:
        """Return the entries under ``prefix`` with the prefix removed."""
        return {k[len(prefix):]: v for k, v in model_state.items() if k.startswith(prefix)}

    # The heads count as shared only if they hold the same parameter names
    # with equal values; compare by name rather than by iteration order.
    teacher_dino = _strip_prefix("head_teacher_dino.")
    teacher_ibot = _strip_prefix("head_teacher_ibot.")
    ibot_separate = True
    if teacher_dino and teacher_ibot and teacher_dino.keys() == teacher_ibot.keys():
        if all(torch.equal(teacher_dino[n], teacher_ibot[n]) for n in teacher_dino):
            ibot_separate = False  # Same weights -> no separate ibot head

    components = {
        "backbone_teacher.pt": "backbone_teacher.vit",
        "backbone_student.pt": "backbone_student.vit",
        "head_student_dino.pt": "head_student_dino",
        "head_teacher_dino.pt": "head_teacher_dino",
    }
    if ibot_separate:
        components["head_student_ibot.pt"] = "head_student_ibot"
        components["head_teacher_ibot.pt"] = "head_teacher_ibot"

    for filename, key in components.items():
        # match on "<key>." so that e.g. "head_teacher_dino_x.*" is not
        # swept into "head_teacher_dino"
        sub_state_dict = _strip_prefix(f"{key}.")
        if not sub_state_dict:
            print(f"[WARNING] No parameters found for {key}, skipping...")
            continue
        torch.save(sub_state_dict, os.path.join(output_dir, filename))

    print(f"Components extracted to: {output_dir}")
204
-
205
-
206
def extract_model_from_checkpoint_siglip(checkpoint_path: str):
    """Split a SigLIP-style training checkpoint into per-component files.

    The model state dict is read from ``checkpoint["model"]`` (or the
    checkpoint itself when that key is absent). Each component's parameters
    are written, with the component prefix stripped, to
    ``<checkpoint without extension>/<component>.pt``. Components with no
    parameters in the checkpoint are skipped with a printed warning.

    Args:
        checkpoint_path: path to the checkpoint file.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # use on trusted checkpoints.
    checkpoint = torch.load(
        checkpoint_path,
        weights_only=False,
        map_location="cpu",
    )
    model_state = checkpoint.get("model", checkpoint)

    # Strip only the trailing extension. Removing ".pt" anywhere in the path
    # would corrupt names such as "run.pth".
    path_str = str(checkpoint_path)
    output_dir, extension = os.path.splitext(path_str)
    if not extension:
        # without an extension the directory would collide with the file itself
        output_dir = path_str + "_components"
    os.makedirs(output_dir, exist_ok=True)

    components = {
        "backbone_image.pt": "backbone_image",
        "backbone_text.pt": "backbone_text",
        "feature_comb_image.pt": "feature_comb_image",
        "projection_image.pt": "projection_image",
        "projection_text.pt": "projection_text",
    }

    for filename, key in components.items():
        # match on "<key>." so that sibling names sharing the same stem are
        # not swept into this component
        prefix = f"{key}."
        sub_state_dict = {
            k[len(prefix):]: v for k, v in model_state.items() if k.startswith(prefix)
        }
        if not sub_state_dict:
            print(f"[WARNING] No parameters found for {key}, skipping...")
            continue
        torch.save(sub_state_dict, os.path.join(output_dir, filename))

    print(f"Components extracted to: {output_dir}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/collate.py DELETED
@@ -1,120 +0,0 @@
1
- from typing import List, Callable, Optional
2
-
3
- import torch
4
-
5
- MONAI_IMPORT_ERROR = None
6
- try:
7
- from monai.data import list_data_collate
8
- except ImportError as e:
9
- list_data_collate = lambda x: x # type: ignore
10
- MONAI_IMPORT_ERROR = e
11
-
12
-
13
def extended_collate_dino(samples_list: List) -> dict:
    """
    Collate samples with MONAI's list_data_collate and concatenate the views.

    Args:
        samples_list: List of samples. After collation the batch must provide
            the keys 'image_global_views' and 'image_local_views', each a
            sequence of tensors that can be concatenated along dim 0.

    Returns:
        A dictionary with the keys 'global_views' and 'local_views', each
        holding the corresponding views concatenated along dim 0.

    Raises:
        ImportError: if MONAI is not installed.
    """
    if MONAI_IMPORT_ERROR is not None:
        raise ImportError(
            "MONAI is required to use extended_collate_dino but not installed. "
            "Please install MONAI to use this collate function."
        ) from MONAI_IMPORT_ERROR

    # Apply MONAI's list_data_collate
    collated_data = list_data_collate(samples_list)

    # Merge the per-view entries into one tensor per view type
    global_views = torch.cat(collated_data["image_global_views"], dim=0)
    local_views = torch.cat(collated_data["image_local_views"], dim=0)

    return {
        "global_views": global_views,
        "local_views": local_views,
    }
45
-
46
-
47
def extended_collate_siglip(
    samples_list: List,
    tokenizer: Optional[Callable] = None,
    tokenizer_padding: bool = True,
    tokenizer_truncation: bool = True,
    tokenizer_max_length: Optional[int] = 1024,
    return_filenames: bool = False
) -> dict:
    """
    Collate samples with MONAI and optionally tokenize the reports.

    Args:
        samples_list: List of samples containing 'image' and 'report'.
        tokenizer: Object exposing ``batch_encode_plus``; applied to the
            collated 'report' entry when given.
        tokenizer_padding: Forwarded as ``padding``.
        tokenizer_truncation: Forwarded as ``truncation``.
        tokenizer_max_length: Forwarded as ``max_length``.
        return_filenames: When True, add a 'filename' list taken from each
            image's ``meta["filename_or_obj"]`` (if the first sample has it).

    Returns:
        The collated batch, extended with 'input_ids' and 'attention_mask'
        tensors when tokenization ran, and 'filename' when requested.
    """
    if MONAI_IMPORT_ERROR is not None:
        raise ImportError(
            "MONAI is required to use extended_collate_siglip but not installed. "
            "Please install MONAI to use this collate function."
        ) from MONAI_IMPORT_ERROR

    batch = list_data_collate(samples_list)

    if return_filenames and "image" in batch.keys():
        # only the first sample is probed for metadata
        first_image = samples_list[0]["image"].data
        if hasattr(first_image, "meta") and "filename_or_obj" in first_image.meta:
            batch["filename"] = [
                sample["image"].data.meta["filename_or_obj"] for sample in samples_list
            ]

    if tokenizer is None or "report" not in batch.keys():
        return batch

    encoded = tokenizer.batch_encode_plus(
        batch["report"],
        add_special_tokens=True,
        padding=tokenizer_padding,
        truncation=tokenizer_truncation,
        max_length=tokenizer_max_length,
    )
    batch["input_ids"] = torch.tensor(encoded["input_ids"])
    batch["attention_mask"] = torch.tensor(encoded["attention_mask"])

    return batch
94
-
95
-
96
def collate_add_filenames(samples_list: List) -> dict:
    """
    Collate samples with MONAI's list_data_collate and attach filenames.

    Args:
        samples_list: List of samples containing 'image' with metadata.
    Returns:
        The collated batch; a 'filename' list is added when the first
        sample's image carries ``meta["filename_or_obj"]``.
    """
    if MONAI_IMPORT_ERROR is not None:
        raise ImportError(
            "MONAI is required to use collate_add_filenames but not installed. "
            "Please install MONAI to use this collate function."
        ) from MONAI_IMPORT_ERROR

    batch = list_data_collate(samples_list)

    if "image" not in batch.keys():
        return batch

    # only the first sample is probed for metadata
    first_image = samples_list[0]["image"].data
    if hasattr(first_image, "meta") and "filename_or_obj" in first_image.meta:
        batch["filename"] = [
            sample["image"].data.meta["filename_or_obj"] for sample in samples_list
        ]

    return batch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/config.py DELETED
@@ -1,91 +0,0 @@
1
- import os
2
- import math
3
-
4
- from spectre.utils import _utils, distributed
5
-
6
- OMEGACONF_IMPORT_ERROR = None
7
- try:
8
- from omegaconf import OmegaConf
9
- except ImportError as e:
10
- OmegaConf = None # type: ignore
11
- OMEGACONF_IMPORT_ERROR = e
12
-
13
-
14
def apply_scaling_rules_to_cfg(cfg):
    """
    Apply learning rate scaling rules to the configuration object.

    ``cfg.optim.lr`` is set in place from ``cfg.optim.base_lr`` and, unless
    the rule is "constant", multiplied by a factor derived from the effective
    batch size::

        factor = batch_size_per_gpu * world_size / ref_batch_size * grad_accum_steps

    Supported rules are "constant", "sqrt_wrt_<ref>" and "linear_wrt_<ref>".

    Returns:
        The same ``cfg`` object.

    Raises:
        NotImplementedError: if the rule is malformed or its type is unsupported.
    """
    cfg.optim.lr = cfg.optim.base_lr

    rule = cfg.optim.scaling_rule
    if rule == "constant":
        return cfg

    try:
        scaling_type, ref_batch_size = rule.split("_wrt_")
        ref_batch_size = float(ref_batch_size)
    except ValueError as err:
        raise NotImplementedError(f"Unknown scaling rule: {rule}") from err

    # reject unsupported types before touching the distributed state
    if scaling_type not in ("sqrt", "linear"):
        raise NotImplementedError(f"Unsupported scaling type: {scaling_type}")

    scale_factor = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    scale_factor /= ref_batch_size
    scale_factor *= cfg.train.grad_accum_steps

    if scaling_type == "sqrt":
        cfg.optim.lr *= math.sqrt(scale_factor)
    else:
        cfg.optim.lr *= scale_factor

    return cfg
43
-
44
-
45
def write_config(cfg, output_dir, name="config.yaml"):
    """Serialize ``cfg`` with OmegaConf to ``output_dir/name`` and return that path.

    Raises:
        ImportError: if OmegaConf is not installed.
    """
    if OMEGACONF_IMPORT_ERROR is not None:
        raise ImportError(
            "OmegaConf is required to use write_config but not installed. "
            "Please install OmegaConf to use this function."
        ) from OMEGACONF_IMPORT_ERROR

    destination = os.path.join(output_dir, name)
    with open(destination, "w") as handle:
        OmegaConf.save(config=cfg, f=handle)
    return destination
56
-
57
-
58
def get_cfg_from_args(args, default_config):
    """Build the run config from defaults, the config file and CLI overrides.

    ``args.output_dir`` is made absolute and appended to ``args.opts`` as a
    ``train.output_dir=...`` override (both attributes are updated on
    ``args``). Merge order, lowest to highest priority: ``default_config``,
    the file at ``args.config_file``, then ``args.opts``.

    Raises:
        ImportError: if OmegaConf is not installed.
    """
    if OMEGACONF_IMPORT_ERROR is not None:
        raise ImportError(
            "OmegaConf is required to use get_cfg_from_args but not installed. "
            "Please install OmegaConf to use this function."
        ) from OMEGACONF_IMPORT_ERROR

    args.output_dir = os.path.abspath(args.output_dir)
    if args.opts is None:
        args.opts = []
    args.opts += [f"train.output_dir={args.output_dir}"]

    layers = (
        OmegaConf.create(default_config),
        OmegaConf.load(args.config_file),
        OmegaConf.from_cli(args.opts),
    )
    return OmegaConf.merge(*layers)
72
-
73
-
74
def random_seed(args):
    """Seed the RNGs with ``args.seed`` (0 when absent) offset by the global rank."""
    base_seed = getattr(args, "seed", 0)
    # offset by rank so every process gets a different seed
    _utils.fix_random_seeds(base_seed + distributed.get_global_rank())
79
-
80
-
81
def setup(args, default_config):
    """
    Create configs and perform basic setups.

    Steps, in order: build the config, create the output directory, seed the
    RNGs, initialize the accelerator, apply the learning-rate scaling rule
    and write the resolved config to the output directory.

    Returns:
        A ``(cfg, accelerator)`` tuple.
    """
    config = get_cfg_from_args(args, default_config)
    os.makedirs(args.output_dir, exist_ok=True)
    random_seed(args)
    accelerator = distributed.init_distributed(config)
    # scaling runs after distributed init so the world size is known
    apply_scaling_rules_to_cfg(config)
    write_config(config, args.output_dir)
    return config, accelerator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/dataloader.py DELETED
@@ -1,126 +0,0 @@
1
- from __future__ import annotations
2
- import os
3
- from typing import Union, Callable, Optional, List
4
-
5
- import torch
6
- from torch.utils.data import ConcatDataset
7
-
8
- MONAI_IMPORT_ERROR = None
9
- try:
10
- import monai.data as data
11
- except ImportError as e:
12
- data = None # type: ignore
13
- MONAI_IMPORT_ERROR = e
14
-
15
-
16
-
17
def get_dataloader(
    datasets: Union[str, List[str]],
    data_dir: str,
    include_reports: bool = False,
    include_labels: bool = False,
    cache_dataset: bool = False,
    cache_dir: Optional[str] = None,
    use_gds: bool = False,
    transform: Optional[Callable] = None,
    fraction: float = 1.0,
    batch_size: int = 64,
    num_workers: int = 4,
    pin_memory: bool = True,
    shuffle: bool = True,
    collate_fn: Optional[Callable] = None,
    drop_last: bool = True,
    persistent_workers: bool = True,
    use_thread: bool = False,
) -> "DataLoader":
    """
    Get dataloader for training.

    Args:
        datasets: One dataset name or a list of names (case-insensitive).
            Several datasets are combined with ``ConcatDataset``.
        data_dir: Root directory containing one sub-folder per dataset.
        include_reports: Forwarded to the report-capable datasets; restricts
            ``datasets`` to 'ct_rate', 'merlin' and 'inspect'.
        include_labels: Forwarded to the label-capable datasets; restricts
            ``datasets`` to 'abdomen_atlas' and 'abdomenct_1k'.
        cache_dataset: Use the ``*PersistentDataset`` (or ``*GDSDataset``)
            variant of each dataset class. Requires ``cache_dir``.
        cache_dir: Root cache directory; one sub-folder per dataset is used.
        use_gds: Use the ``*GDSDataset`` variant on the current CUDA device.
        transform: Transform passed to each dataset.
        fraction: Passed to each dataset as ``fraction``.
        batch_size, num_workers, shuffle, drop_last: Passed to the loader.
        pin_memory, persistent_workers: Passed to the loader unless
            ``use_thread`` is set. ``persistent_workers`` is dropped when
            ``num_workers`` is 0, since DataLoader rejects that combination.
        collate_fn: Optional collate function for the loader.
        use_thread: Use MONAI's ``ThreadDataLoader`` instead of ``DataLoader``.

    Raises:
        ImportError: if MONAI is not installed.
        ValueError: if ``cache_dataset`` is set without ``cache_dir``.
        NotImplementedError: if a dataset name is unknown.
    """
    if MONAI_IMPORT_ERROR is not None:
        raise ImportError(
            "MONAI is required to use get_dataloader but not installed. "
            "Please install MONAI to use this function."
        ) from MONAI_IMPORT_ERROR

    if isinstance(datasets, str):
        datasets = [datasets]

    # Validate constraints on lower-cased names, matching the lookup below.
    requested = {ds.lower() for ds in datasets}
    if include_reports:
        assert requested.issubset({"ct_rate", "merlin", "inspect"}), \
            "When include_reports=True, only 'ct_rate', 'merlin', and 'inspect' are allowed."
    if include_labels:
        assert requested.issubset({"abdomen_atlas", "abdomenct_1k"}), \
            "When include_labels=True, only 'abdomen_atlas' and 'abdomenct_1k' are allowed."
    if use_gds:
        assert cache_dataset, "GDS requires cache_dataset=True."
        assert torch.cuda.is_available(), "GDS requires CUDA to be available."
    if cache_dataset and cache_dir is None:
        # fail early with a clear message instead of a TypeError from os.path.join
        raise ValueError("cache_dir must be provided when cache_dataset=True.")

    # Dataset configurations
    DATASET_CONFIGS = {
        "ct_rate": {"folder": "CT-RATE", "base_name": "CTRate",
                    "extra": {"include_reports": include_reports}},
        "inspect": {"folder": "INSPECT", "base_name": "Inspect",
                    "extra": {"include_reports": include_reports}},
        "merlin": {"folder": "MERLIN", "base_name": "Merlin",
                   "extra": {"include_reports": include_reports}},
        "nlst": {"folder": "NLST", "base_name": "Nlst"},
        "amos": {"folder": "Amos", "base_name": "Amos"},
        "abdomen_atlas": {"folder": "AbdomenAtlas1.0Mini", "base_name": "AbdomenAtlas",
                          "extra": {"include_labels": include_labels}},
        "panorama": {"folder": "PANORAMA", "base_name": "Panorama"},
        "abdomenct_1k": {"folder": "AbdomenCT-1K", "base_name": "AbdomenCT1K",
                         "extra": {"include_labels": include_labels}},
    }

    datasets_list = []
    for ds in datasets:
        if ds.lower() not in DATASET_CONFIGS:
            raise NotImplementedError(f"Dataset {ds} not implemented.")

        cfg = DATASET_CONFIGS[ds.lower()]
        folder = cfg["folder"]
        extra_args = cfg.get("extra", {})

        kwargs = {
            "data_dir": os.path.join(data_dir, folder),
            "transform": transform,
            "fraction": fraction,
            **extra_args,
        }

        # Class name is "<Base>Dataset", "<Base>PersistentDataset" or
        # "<Base>GDSDataset", resolved lazily from spectre.data.
        base_name = cfg["base_name"]
        class_suffix = "Dataset"
        if cache_dataset:
            class_suffix = "GDSDataset" if use_gds else "PersistentDataset"

        class_name = f"{base_name}{class_suffix}"
        DatasetClass = getattr(__import__("spectre.data", fromlist=[class_name]), class_name)

        if cache_dataset:
            kwargs["cache_dir"] = os.path.join(cache_dir, folder)
            if use_gds:
                kwargs["device"] = torch.cuda.current_device()

        datasets_list.append(DatasetClass(**kwargs))

    dataset = datasets_list[0] if len(datasets_list) == 1 else ConcatDataset(datasets_list)

    loader_cls = getattr(data, "ThreadDataLoader" if use_thread else "DataLoader")
    loader_kwargs = {
        "dataset": dataset,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "shuffle": shuffle,
        "drop_last": drop_last,
    }

    if not use_thread:
        loader_kwargs.update({
            "pin_memory": pin_memory,
            # DataLoader raises if persistent_workers is set without workers
            "persistent_workers": persistent_workers and num_workers > 0,
        })
    if collate_fn is not None:
        loader_kwargs["collate_fn"] = collate_fn

    return loader_cls(**loader_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/distributed.py DELETED
@@ -1,92 +0,0 @@
1
- import os
2
-
3
- import torch.distributed as dist
4
-
5
- ACCELERATE_IMPORT_ERROR = None
6
- try:
7
- from accelerate import Accelerator, DataLoaderConfiguration
8
- except ImportError as e:
9
- Accelerator = None # type: ignore
10
- DataLoaderConfiguration = None # type: ignore
11
- ACCELERATE_IMPORT_ERROR = e
12
-
13
-
14
def is_enabled() -> bool:
    """
    Returns:
        True if torch.distributed is available and its process group is initialized
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
20
-
21
-
22
def get_global_size() -> int:
    """
    Returns:
        Number of processes in the distributed group (1 when not distributed)
    """
    return dist.get_world_size() if is_enabled() else 1
30
-
31
-
32
def get_global_rank() -> int:
    """
    Returns:
        The rank of the current process in the distributed group (0 when not distributed)
    """
    return dist.get_rank() if is_enabled() else 0
40
-
41
-
42
def get_local_size() -> int:
    """
    Returns:
        Number of processes on the current machine (1 when not distributed
        or when no launcher variable is set)

    The value is read from the ``LOCAL_SIZE`` environment variable, falling
    back to ``LOCAL_WORLD_SIZE``, which is the name torchrun exports.
    """
    if not is_enabled():
        return 1
    for var_name in ("LOCAL_SIZE", "LOCAL_WORLD_SIZE"):
        value = os.environ.get(var_name)
        if value is not None:
            return int(value)
    return 1
50
-
51
-
52
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process on the current machine, read from
        the ``LOCAL_RANK`` environment variable (0 when not distributed or unset)
    """
    return int(os.environ.get("LOCAL_RANK", 0)) if is_enabled() else 0
60
-
61
-
62
def init_distributed(cfg):
    """
    Initialize distributed training.

    Builds an ``Accelerator`` from ``cfg.train`` (gradient accumulation
    steps, ``pin_memory`` used as the dataloader ``non_blocking`` flag) and,
    when ``cfg.train.log_wandb`` is set, starts the "spectre" wandb tracker
    with the config sections flattened into a single dict.

    Returns:
        The configured ``Accelerator``.

    Raises:
        ImportError: if Accelerate is not installed.
    """
    if ACCELERATE_IMPORT_ERROR is not None:
        raise ImportError(
            "Accelerate is required to use init_distributed but not installed. "
            "Please install Accelerate to use this function."
        ) from ACCELERATE_IMPORT_ERROR

    use_wandb = cfg.train.log_wandb

    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.train.grad_accum_steps,
        log_with="wandb" if use_wandb else None,
        dataloader_config=DataLoaderConfiguration(
            non_blocking=cfg.train.pin_memory,
        ),
    )

    if not use_wandb:
        return accelerator

    # flatten {section: {key: value}} into {key: value}; later sections win
    flat_config = {}
    for section in cfg.values():
        for key, value in section.items():
            flat_config[key] = value

    accelerator.init_trackers(
        project_name="spectre",
        config=flat_config,
        init_kwargs={
            "dir": os.path.join(cfg.train.output_dir, "logs"),
        },
    )

    return accelerator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/lora.py DELETED
@@ -1,38 +0,0 @@
1
- import torch.nn as nn
2
- import loralib as lora
3
-
4
-
5
def add_lora_adapters(
    root_module: nn.Module,
    r: int = 8,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    target_keywords: tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "o_proj")
) -> None:
    """
    Recursively traverses the model and replaces every `nn.Linear`
    whose name contains one of `target_keywords` with a LoRA-augmented
    linear layer from loralib.

    The module tree is modified in place. Only the child's own attribute
    name (not its full dotted path) is matched against `target_keywords`.

    Args:
        root_module: module whose descendants are scanned.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        lora_dropout: dropout applied on the LoRA branch.
        target_keywords: substrings selecting which linear layers to replace.
    """
    for name, child in list(root_module.named_children()):
        # If the child is itself a container, recurse first
        add_lora_adapters(child, r, lora_alpha, lora_dropout, target_keywords)

        if not (isinstance(child, nn.Linear) and any(k in name for k in target_keywords)):
            continue

        lora_layer = lora.Linear(  # loralib wrapper
            in_features=child.in_features,
            out_features=child.out_features,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias=child.bias is not None,
        )

        # copy original weights so that behaviour is identical pre-training
        lora_layer.weight.data = child.weight.data.clone()
        if child.bias is not None:
            lora_layer.bias.data = child.bias.data.clone()

        # The freshly created LoRA matrices live on the default device/dtype;
        # align them with the layer being replaced so a model already moved
        # to GPU or cast to half precision keeps working.
        lora_layer.to(device=child.weight.device, dtype=child.weight.dtype)

        setattr(root_module, name, lora_layer)  # hot-swap!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/masking.py DELETED
@@ -1,196 +0,0 @@
1
- import math
2
- from typing import Optional, Tuple, Union
3
-
4
- import torch
5
-
6
-
7
def _random_block_mask(
    size: Tuple[int, int, int],
    num_masks: int,
    min_num_masks_per_block: int = 4,
    max_num_masks_per_block: Optional[int] = None,
    max_attempts_per_block: int = 10,
    generator: Optional[torch.Generator] = None,
    device: Optional[Union[torch.device, str]] = None,
) -> torch.Tensor:
    """3D helper: generate a (H, W, D) boolean mask by placing cuboidal blocks.

    Blocks are placed at random positions (and may overlap) until at least
    ``num_masks`` voxels are masked or ``max_attempts_per_block`` blocks have
    been tried; any shortfall is then filled with randomly chosen single
    voxels. Because block volumes are rounded up, the result can contain
    slightly more than ``num_masks`` masked voxels.

    Args:
        size: (H, W, D)
        num_masks: target total number of masked voxels (clamped to [0, H*W*D])
        min_num_masks_per_block / max_num_masks_per_block: voxel-range per
            block. The maximum defaults to ``num_masks``; the range is clamped
            so it is never empty, even when ``num_masks`` is below the minimum.
        max_attempts_per_block: bound on both the number of blocks tried and
            the shape retries per block.
        generator: optional torch.Generator for reproducibility.
        device: device of the returned mask.

    Returns:
        Boolean tensor of shape (H, W, D).
    """
    H, W, D = size
    dims = (H, W, D)
    total = H * W * D
    num_masks = min(max(0, int(num_masks)), total)

    if max_num_masks_per_block is None:
        max_num_masks_per_block = max(1, num_masks)

    # torch.randint needs low < high: keep the per-block range non-empty and
    # at least one voxel, e.g. when only 1-3 voxels are requested overall.
    block_lo = max(1, min(min_num_masks_per_block, max_num_masks_per_block))
    block_hi = max(block_lo, max_num_masks_per_block)

    def _randint(low: int, high: int) -> int:
        """Draw one integer from [low, high) with the shared generator."""
        return int(torch.randint(low, high, (1,), generator=generator).item())

    mask = torch.zeros((H, W, D), dtype=torch.bool, device=device)
    masked_count = 0
    global_attempts = 0

    # axis orders (first, second, derived) used to reduce shape bias
    orders = [(0, 1, 2), (1, 2, 0), (2, 0, 1)]

    # Try to place blocks until we have enough masked voxels or we exceed attempts
    while masked_count < num_masks and global_attempts < max_attempts_per_block:
        global_attempts += 1

        # choose target voxels for this block
        target_voxels = _randint(block_lo, block_hi + 1)

        shape = [1, 1, 1]
        found = False
        local_attempts = 0
        while not found and local_attempts < max_attempts_per_block:
            local_attempts += 1

            first, second, derived = orders[_randint(0, 3)]

            # pick the first extent freely, the second so that the product
            # stays within the target, then derive the third (rounded up)
            shape[first] = _randint(1, min(dims[first], target_voxels) + 1)
            max_second = max(1, min(dims[second], target_voxels // shape[first]))
            shape[second] = _randint(1, max_second + 1)
            needed = math.ceil(target_voxels / (shape[first] * shape[second]))
            if needed <= dims[derived]:
                shape[derived] = max(1, needed)
                found = True

            # fallback: exhaustive small-to-large factorization search
            # NOTE(review): assumed to run inside the retry loop — confirm
            if not found:
                for hh in range(1, min(H, target_voxels) + 1):
                    for ww in range(1, min(W, target_voxels // hh) + 1):
                        dd = math.ceil(target_voxels / (hh * ww))
                        if dd <= D:
                            shape[0], shape[1], shape[2] = hh, ww, dd
                            found = True
                            break
                    if found:
                        break

        if not found:
            # couldn't find a fitting block this global attempt; move on
            continue

        # clamp block dims to volume just in case and ensure at least 1
        h = min(max(1, int(shape[0])), H)
        w = min(max(1, int(shape[1])), W)
        d = min(max(1, int(shape[2])), D)

        # choose random location so block fits
        x0 = _randint(0, (H - h) + 1) if H - h > 0 else 0
        y0 = _randint(0, (W - w) + 1) if W - w > 0 else 0
        z0 = _randint(0, (D - d) + 1) if D - d > 0 else 0

        mask[x0 : x0 + h, y0 : y0 + w, z0 : z0 + d] = True
        # recount from the mask because blocks may overlap
        masked_count = int(mask.sum().item())

    # If still short, fill remaining voxels at random positions
    if masked_count < num_masks:
        remaining = num_masks - masked_count
        indices = torch.nonzero(~mask, as_tuple=False)
        if indices.numel() > 0:
            perm = torch.randperm(indices.shape[0], generator=generator, device=mask.device)
            pick = indices[perm[:remaining]]
            mask[pick[:, 0], pick[:, 1], pick[:, 2]] = True

    return mask
132
-
133
-
134
- def random_block_mask(
135
- size: Tuple[int, int, int, int],
136
- batch_mask_ratio: float = 0.5,
137
- min_image_mask_ratio: float = 0.1,
138
- max_image_mask_ratio: float = 0.5,
139
- min_num_masks_per_block: int = 4,
140
- max_num_masks_per_block: Optional[int] = None,
141
- max_attempts_per_block: int = 10,
142
- generator: Optional[torch.Generator] = None,
143
- device: Optional[Union[torch.device, str]] = None,
144
- ) -> torch.Tensor:
145
- """Create random block masks for 3D volumes only.
146
-
147
- Args:
148
- size: (B, H, W, D)
149
- batch_mask_ratio: fraction of images in the batch to apply masking to
150
- min_image_mask_ratio / max_image_mask_ratio: per-image mask fraction range
151
- min_num_masks_per_block / max_num_masks_per_block: voxels per block range
152
- max_attempts_per_block: attempts to find a fitting block
153
- generator: optional torch.Generator for reproducibility.
154
- device: device for returned tensor
155
-
156
- Returns:
157
- boolean tensor with shape (B, H, W, D)
158
- """
159
- if len(size) != 4:
160
- raise ValueError("size must be (B, H, W, D) for 3D masking.")
161
-
162
- B, H, W, D = size
163
-
164
- if max_image_mask_ratio < min_image_mask_ratio:
165
- raise ValueError("max_image_mask_ratio must be >= min_image_mask_ratio.")
166
-
167
- num_images_masked = int(B * batch_mask_ratio)
168
- probs = torch.linspace(min_image_mask_ratio, max_image_mask_ratio, num_images_masked + 1).tolist()
169
-
170
- image_masks = []
171
- total_voxels = H * W * D
172
-
173
- for prob_min, prob_max in zip(probs[:-1], probs[1:]):
174
- # choose number of masked voxels for this image
175
- u = float(prob_min + (prob_max - prob_min) * torch.rand(1, generator=generator).item())
176
- num_mask = int(total_voxels * u)
177
- image_masks.append(
178
- _random_block_mask(
179
- size=(H, W, D),
180
- num_masks=num_mask,
181
- min_num_masks_per_block=min_num_masks_per_block,
182
- max_num_masks_per_block=max_num_masks_per_block,
183
- max_attempts_per_block=max_attempts_per_block,
184
- generator=generator,
185
- device=device,
186
- )
187
- )
188
-
189
- # Add non-masked images (all False) to fill the batch
190
- for _ in range(num_images_masked, B):
191
- image_masks.append(torch.zeros((H, W, D), dtype=torch.bool, device=device))
192
-
193
- perm = torch.randperm(B, generator=generator).tolist()
194
- image_masks = [image_masks[i] for i in perm]
195
-
196
- return torch.stack(image_masks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/param_groups.py DELETED
@@ -1,118 +0,0 @@
1
- import torch.nn as nn
2
-
3
-
4
def get_vit_lr_decay_rate(
    name: str,
    llrd_factor: float = 1.0,
    num_layers: int = 12,
    force_is_backbone: bool = False,
    shift: int = 0,
) -> float:
    """
    Get the layer-wise learning rate decay (LLRD) rate for a given parameter name.

    A parameter is assigned a layer id: 0 for embedding/token parameters,
    ``block_index + 1 + shift`` for parameters inside a ``blocks.<i>`` module,
    and ``num_layers + 1`` for everything else. The multiplier is
    ``llrd_factor ** (num_layers + 1 - layer_id)``, so parameters that are not
    recognised keep a multiplier of 1.0.

    Args:
        name:
            The name of the parameter.
        llrd_factor:
            The decay factor for each layer.
        num_layers:
            The total number of layers in the model.
        force_is_backbone:
            If True, forces the function to treat the parameter as part of the backbone,
            even when its name does not start with "backbone".
        shift:
            An integer to shift the layer ids, useful when combining multiple modules.

    Returns:
        The learning rate multiplier for the parameter.
    """
    # Default: the parameter sits "after" the last block and is not decayed.
    layer_id = num_layers + 1

    if name.startswith("backbone") or force_is_backbone:
        # Prefix a dot so top-level names ("pos_embed", "blocks.0.attn...") match
        # the same dotted patterns as nested names ("backbone.blocks.0.attn...").
        # Without this, a bare backbone used with force_is_backbone=True would
        # never match and would silently receive no layer-wise decay.
        dotted_name = f".{name}"
        embedding_markers = (
            ".pos_embed",
            ".patch_embed",
            ".patch_proj",
            ".mask_token",
            ".cls_token",
            ".reg_token",
        )
        if any(marker in dotted_name for marker in embedding_markers):
            layer_id = 0
        elif ".blocks." in dotted_name:
            # ".blocks.<i>.rest" -> ["", "blocks", "<i>", ...]
            block_index = dotted_name[dotted_name.find(".blocks."):].split(".")[2]
            layer_id = int(block_index) + 1 + shift

    return llrd_factor ** (num_layers + 1 - layer_id)
44
-
45
-
46
def get_param_groups_with_decay(
    model: nn.Module,
    llrd_factor: float = 1.0,
    patch_embed_lr_mult: float = 1.0,
    projection_head_wd_mult: float = 1.0,
    lora_lr_factor: float = 1.0,
    num_layers: int | None = None,
):
    """Build one optimizer parameter group per trainable parameter.

    Each group is a dict with the keys ``"name"``, ``"params"``, ``"lr_mult"``
    and ``"wd_mult"``. Parameters with ``requires_grad=False`` are skipped.

    Args:
        model:
            Model whose named parameters are grouped.
        llrd_factor:
            Layer-wise learning rate decay factor passed to
            ``get_vit_lr_decay_rate``.
        patch_embed_lr_mult:
            Extra learning rate multiplier for parameters whose name contains
            "patch_embed".
        projection_head_wd_mult:
            Weight decay multiplier for parameters whose name contains "head"
            or "projection".
        lora_lr_factor:
            Learning rate multiplier for parameters whose name contains
            "lora_"; these bypass layer-wise decay entirely.
        num_layers:
            Number of layers used for the decay schedule. When None it is
            derived from the attributes found on ``model``.

    Returns:
        A list of parameter-group dicts, in ``model.named_parameters()`` order.
    """
    force_is_backbone = False
    shift = 0

    if num_layers is None:
        # Infer the depth from whichever known model layout is present.
        if hasattr(model, "n_blocks"):
            num_layers = model.n_blocks
            force_is_backbone = True
        elif hasattr(model, "blocks"):
            num_layers = len(model.blocks)
            force_is_backbone = True
        elif hasattr(model, "backbone") and hasattr(model.backbone, "blocks"):
            num_layers = len(model.backbone.blocks)
        elif hasattr(model, "backbone_student") and hasattr(model.backbone_student, "blocks"):  # DINO specific
            num_layers = len(model.backbone_student.blocks)
        elif (
            hasattr(model, "backbone_student")
            and hasattr(model.backbone_student, "vit")
            and hasattr(model.backbone_student.vit, "blocks")
        ):  # DINOv2 specific
            num_layers = len(model.backbone_student.vit.blocks)
        elif hasattr(model, "backbone_image") and hasattr(model.backbone_image, "blocks"):  # SigLIP specific
            if not hasattr(model, "feature_comb_image") or model.feature_comb_image is None:
                num_layers = len(model.backbone_image.blocks)
            else:
                # The feature combiner is stacked on top of the image backbone,
                # so its block ids continue after the backbone's blocks.
                num_layers = len(model.backbone_image.blocks) + len(model.feature_comb_image.blocks)
                shift = len(model.backbone_image.blocks)
                # NOTE(review): assumed to belong to this branch together with
                # `shift` (needed because "feature_comb*" names do not start
                # with "backbone") -- confirm against the original indentation.
                force_is_backbone = True
        else:
            num_layers = 0

    all_param_groups = []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        if "lora_" not in n:
            # Only feature-combiner parameters get their layer ids shifted.
            s = shift if "feature_comb" in n else 0
            llrd_rate = get_vit_lr_decay_rate(
                n, llrd_factor, num_layers, force_is_backbone, s,
            )

            d = {
                "name": n,
                "params": p,
                "lr_mult": llrd_rate,
                "wd_mult": 1.0,
            }

            if "head" in n or "projection" in n:
                d["wd_mult"] = projection_head_wd_mult

            # No weight decay on biases, norm parameters, layer-scale gamma and
            # "fourrier_w" parameters. This deliberately overrides the head /
            # projection multiplier set just above.
            if n.endswith("bias") or "norm" in n or "gamma" in n or "fourrier_w" in n:
                d["wd_mult"] = 0.0

            if "patch_embed" in n:
                d["lr_mult"] *= patch_embed_lr_mult

        else:
            # LoRA parameters
            d = {
                "name": n,
                "params": p,
                "lr_mult": lora_lr_factor,
                "wd_mult": 1.0,
            }

        all_param_groups.append(d)
    return all_param_groups
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spectre/utils/scheduler.py DELETED
@@ -1,236 +0,0 @@
1
- import warnings
2
- from typing import Optional
3
-
4
- import numpy as np
5
- import torch
6
-
7
-
8
def linear_warmup_schedule(
    step: int,
    warmup_steps: int,
    start_value: float,
    end_value: float,
) -> float:
    """Linearly ramp a value from start_value to end_value during warmup.

    Args:
        step:
            Current step number.
        warmup_steps:
            Number of steps over which the value is ramped.
        start_value:
            Value at step 0.
        end_value:
            Value returned once step reaches warmup_steps.

    Returns:
        The interpolated value while ``step < warmup_steps``, otherwise
        ``end_value``.

    Raises:
        ValueError: If an argument is out of range or start_value exceeds
            end_value.
    """
    # Checked in order; the first violated constraint is the one reported.
    constraints = (
        (warmup_steps < 0, f"Warmup steps {warmup_steps} can't be negative."),
        (step < 0, f"Current step number {step} can't be negative."),
        (start_value < 0, f"Start value {start_value} can't be negative."),
        (end_value <= 0, f"End value {end_value} can't be non-positive."),
        (
            start_value > end_value,
            f"Start value {start_value} must be less than or equal to end value {end_value}.",
        ),
    )
    for violated, message in constraints:
        if violated:
            raise ValueError(message)

    still_warming_up = step < warmup_steps
    if not still_warming_up:
        return end_value

    return start_value + step / warmup_steps * (end_value - start_value)
30
-
31
-
32
def cosine_schedule(
    step: int,
    max_steps: int,
    start_value: float,
    end_value: float,
    period: Optional[int] = None,
) -> float:
    """Move from start_value towards end_value along a cosine curve.

    Args:
        step:
            Current step number.
        max_steps:
            Total number of steps.
        start_value:
            Starting value.
        end_value:
            Target value.
        period:
            The number of steps over which the cosine function completes a full cycle.
            Defaults to max_steps.

    Returns:
        Cosine decay value.

    Raises:
        ValueError: If step is negative, max_steps is smaller than 1, or a
            given period is not positive.
    """
    if step < 0:
        raise ValueError(f"Current step number {step} can't be negative.")
    if max_steps < 1:
        raise ValueError(f"Total step number {max_steps} must be >= 1.")

    cyclic = period is not None
    if cyclic:
        if period <= 0:
            raise ValueError(f"Period {period} must be >= 1")
    elif step > max_steps:
        # Only the non-cyclic schedule has a defined end to overshoot.
        warnings.warn(
            f"Current step number {step} exceeds max_steps {max_steps}.",
            category=RuntimeWarning,
        )

    span = end_value - start_value

    if cyclic:
        # "cycle" based on period, if provided
        angle = 2 * np.pi * step / period
        return end_value - span * (np.cos(angle) + 1) / 2

    # max_steps == 1 would divide by zero below; step == max_steps is the extra
    # scheduler update Pytorch Lightning performs after the last training epoch.
    if max_steps == 1 or step == max_steps:
        return end_value

    angle = np.pi * step / (max_steps - 1)
    return end_value - span * (np.cos(angle) + 1) / 2
91
-
92
-
93
def cosine_warmup_schedule(
    step: int,
    max_steps: int,
    start_value: float,
    end_value: float,
    warmup_steps: int,
    warmup_start_value: float,
    warmup_end_value: Optional[float] = None,
    period: Optional[int] = None,
) -> float:
    """Linear warmup followed by a cosine schedule from start_value to end_value.

    The first warmup_steps steps ramp linearly; every later step is delegated
    to ``cosine_schedule`` with the step counter re-based to the end of warmup.

    Args:
        step:
            Current step number.
        max_steps:
            Total number of steps.
        start_value:
            Starting value.
        end_value:
            Target value.
        warmup_steps:
            Number of steps for warmup.
        warmup_start_value:
            Starting value for warmup.
        warmup_end_value:
            Target value for warmup. Defaults to start_value.
        period:
            The number of steps over which the cosine function completes a full cycle.
            Defaults to max_steps - warmup_steps.

    Returns:
        Cosine decay value.

    Raises:
        ValueError: If warmup_steps is negative or larger than max_steps.
    """
    if warmup_steps < 0:
        raise ValueError(f"Warmup steps {warmup_steps} can't be negative.")
    if warmup_steps > max_steps:
        raise ValueError(f"Warmup steps {warmup_steps} must be <= max_steps.")
    if step > max_steps:
        warnings.warn(
            f"Current step number {step} exceeds max_steps {max_steps}.",
            category=RuntimeWarning,
        )

    warmup_target = start_value if warmup_end_value is None else warmup_end_value

    in_warmup = step < warmup_steps
    if not in_warmup:
        # With an explicit period the cosine phase ignores the step budget, so
        # a placeholder of 1 is passed; otherwise it gets the steps left after
        # warmup.
        cosine_steps = 1 if period is not None else max_steps - warmup_steps
        return cosine_schedule(
            step=step - warmup_steps,
            max_steps=cosine_steps,
            start_value=start_value,
            end_value=end_value,
            period=period,
        )

    # Use step + 1 to reach the warmup target at the end of warmup. This skips
    # the initial warmup_start_value, which is usually desired when it is 0
    # because that would result in no parameter updates.
    ramp = warmup_target - warmup_start_value
    return warmup_start_value + ramp * (step + 1) / warmup_steps
158
-
159
-
160
class CosineWarmupScheduler(torch.optim.lr_scheduler.LambdaLR):
    """Cosine warmup scheduler for learning rate.

    The learning rate multiplier for each step is computed by
    ``cosine_warmup_schedule`` from the values stored on the instance.

    Args:
        optimizer:
            Optimizer object to schedule the learning rate.
        warmup_epochs:
            Number of warmup epochs or steps.
        max_epochs:
            Total number of training epochs or steps.
        last_epoch:
            The index of last epoch or step.
        start_value:
            Starting learning rate.
        end_value:
            Target learning rate.
        period:
            The number of steps over which the cosine function completes a full
            cycle. Passed through to ``cosine_warmup_schedule``.
        verbose:
            If True, prints a message to stdout for each update. Only forwarded
            to PyTorch when True, because the argument is deprecated in recent
            PyTorch releases.
        warmup_start_value:
            Starting learning rate for warmup.
        warmup_end_value:
            Target learning rate for warmup. Defaults to start_value.

    Note: The `epoch` arguments do not necessarily have to be epochs. Any step or index
    can be used. The naming follows the PyTorch convention to use `epoch` for the steps
    in the scheduler.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        last_epoch: int = -1,
        start_value: float = 1.0,
        end_value: float = 0.001,
        period: Optional[int] = None,
        verbose: bool = False,
        warmup_start_value: float = 0.0,
        warmup_end_value: Optional[float] = None,
    ) -> None:
        # These must be set before super().__init__(), since `scale_lr` reads
        # them and is handed to the parent as `lr_lambda`.
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.start_value = start_value
        self.end_value = end_value
        self.period = period
        self.warmup_start_value = warmup_start_value
        self.warmup_end_value = warmup_end_value

        # Forward `verbose` only when explicitly enabled: unconditionally
        # passing it triggers a deprecation warning on PyTorch >= 2.2 and fails
        # on releases that no longer accept the argument, even for the default
        # verbose=False.
        scheduler_kwargs = {"verbose": verbose} if verbose else {}

        super().__init__(
            optimizer=optimizer,
            lr_lambda=self.scale_lr,
            last_epoch=last_epoch,
            **scheduler_kwargs,
        )

    def scale_lr(self, epoch: int) -> float:
        """Scale learning rate according to the current epoch number.

        Args:
            epoch:
                Current epoch number.

        Returns:
            Scaled learning rate.

        """
        return cosine_warmup_schedule(
            step=epoch,
            max_steps=self.max_epochs,
            start_value=self.start_value,
            end_value=self.end_value,
            warmup_steps=self.warmup_epochs,
            warmup_start_value=self.warmup_start_value,
            warmup_end_value=self.warmup_end_value,
            period=self.period,
        )