Feature Extraction
Transformers
Safetensors
English
spectre
medical-imaging
ct-scan
3d
vision-transformer
self-supervised-learning
foundation-model
radiology
custom_code
Instructions to use cclaess/SPECTRE-Large with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use cclaess/SPECTRE-Large with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="cclaess/SPECTRE-Large", trust_remote_code=True)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("cclaess/SPECTRE-Large", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import os | |
| from functools import partial | |
| from urllib.parse import urlparse | |
| from typing import ( | |
| Tuple, Union, Callable, Literal, | |
| Optional, Type, Set, List, Dict, Any, | |
| ) | |
| import torch | |
| import torch.nn as nn | |
| from timm.layers import PatchDropout, AttentionPoolLatent | |
| from timm.models.vision_transformer import LayerScale, DropPath, Mlp | |
| from huggingface_hub import hf_hub_download, load_state_dict_from_file | |
| from spectre.models.layers import ( | |
| PatchEmbed, | |
| Attention, | |
| RotaryPositionEmbedding, | |
| ) | |
| from spectre.utils import ( | |
| resample_abs_pos_embed, | |
| feature_take_indices, | |
| global_pool_nlc, | |
| ) | |
class Block(nn.Module):
    """Pre-norm transformer block with an attention branch and an MLP branch.

    Each branch is a residual update of the form
    ``x + drop_path(layer_scale(sublayer(norm(x))))``. Layer scale is only
    instantiated when ``init_values`` is truthy, and stochastic depth only when
    ``drop_path > 0``; otherwise ``nn.Identity`` stands in for them.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        attn_mode: str = 'mha',
        q_proj_dim: Optional[int] = None,
        kv_proj_dim: Optional[int] = None,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        proj_drop: float = 0.,
        attn_drop: float = 0.,
        init_values: Optional[float] = None,
        drop_path: float = 0.,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mlp_layer: Type[nn.Module] = Mlp,
    ) -> None:
        """
        Args:
            dim: Token embedding dimension.
            num_heads: Number of attention heads.
            attn_mode: Attention variant forwarded to ``Attention`` as ``mode``.
            q_proj_dim: Query projection dimension forwarded to ``Attention``.
            kv_proj_dim: Key/value projection dimension forwarded to ``Attention``.
            mlp_ratio: MLP hidden width as a multiple of ``dim``.
            qkv_bias: Enable bias on the q/k/v projections.
            qk_norm: Forwarded to ``Attention``.
            proj_bias: Bias on the attention output projection and the MLP layers.
            proj_drop: Dropout after the attention projection and inside the MLP.
            attn_drop: Dropout on the attention weights.
            init_values: Layer-scale init value; layer scale is disabled when falsy.
            drop_path: Stochastic-depth rate; disabled when not > 0.
            act_layer: MLP activation layer.
            norm_layer: Normalization layer (also forwarded to ``Attention``).
            mlp_layer: MLP module class.
        """
        super().__init__()
        with_layer_scale = bool(init_values)
        with_drop_path = drop_path > 0.
        hidden_width = int(dim * mlp_ratio)

        # --- attention branch (registration order: norm1, attn, ls1, drop_path1) ---
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            mode=attn_mode,
            q_proj_dim=q_proj_dim,
            kv_proj_dim=kv_proj_dim,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        if with_layer_scale:
            self.ls1 = LayerScale(dim, init_values=init_values)
        else:
            self.ls1 = nn.Identity()
        self.drop_path1 = DropPath(drop_path) if with_drop_path else nn.Identity()

        # --- MLP branch (registration order: norm2, mlp, ls2, drop_path2) ---
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            bias=proj_bias,
            drop=proj_drop,
        )
        if with_layer_scale:
            self.ls2 = LayerScale(dim, init_values=init_values)
        else:
            self.ls2 = nn.Identity()
        self.drop_path2 = DropPath(drop_path) if with_drop_path else nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        rope=None
    ) -> torch.Tensor:
        """Apply the attention residual update, then the MLP residual update.

        ``rope`` is passed through unchanged to ``self.attn``.
        """
        attn_update = self.attn(self.norm1(x), rope=rope)
        x = x + self.drop_path1(self.ls1(attn_update))
        mlp_update = self.mlp(self.norm2(x))
        return x + self.drop_path2(self.ls2(mlp_update))
class VisionTransformer(nn.Module):
    """ Vision Transformer with 3D Patch Embedding

    Volumetric variant of the timm-style ViT: a 3D patch embedding, optional class /
    register prefix tokens, learnable absolute or rotary position embeddings, a stack
    of transformer blocks, and a pooling + linear classification head.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int, int]] = (128, 128, 64),
        patch_size: Union[int, Tuple[int, int, int]] = (16, 16, 8),
        in_chans: int = 1,
        num_classes: int = 1000,
        global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        attn_mode: str = 'mha',
        q_proj_dim: Optional[int] = None,
        kv_proj_dim: Optional[int] = None,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        proj_bias: bool = True,
        init_values: Optional[float] = None,
        class_token: bool = True,
        pos_embed: str = 'learn',
        no_embed_class: bool = False,
        rope_kwargs: Optional[dict] = None,
        reg_tokens: int = 0,
        pre_norm: bool = False,
        final_norm: bool = True,
        fc_norm: Optional[bool] = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        drop_rate: float = 0.,
        pos_drop_rate: float = 0.,
        patch_drop_rate: float = 0.,
        proj_drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        embed_layer: Callable = PatchEmbed,
        embed_norm_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
        norm_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
        act_layer: Optional[Union[Callable, Type[torch.nn.Module]]] = None,
        block_fn: Type[nn.Module] = Block,
        mlp_layer: Type[nn.Module] = Mlp,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head (0 -> identity head).
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            attn_mode: Attention mode ('mha', 'mqa', 'mla').
            q_proj_dim: Query projection dimension for 'mla' mode.
            kv_proj_dim: Key, value projection dimension for 'mla' mode.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            qk_norm: Forwarded to each block's attention layer.
            proj_bias: Enable bias on attention output projections and MLP layers.
            init_values: Layer-scale init values (layer-scale enabled if truthy).
            class_token: Use class token.
            pos_embed: Position embedding type: '', 'none', 'learn' or 'rope' (default: 'learn').
            no_embed_class: Don't include position embeddings for class (or reg) tokens for learnable pos_embed.
            rope_kwargs: Additional arguments for rotary position embedding ('dtype' defaults to torch.float32).
            reg_tokens: Number of register tokens.
            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
            fc_norm: Move final norm after pool (instead of before); if None, enabled when
                global_pool is one of 'avg', 'avgmax', 'max'.
            dynamic_img_size: Accept varying input sizes; learnable pos_embed is resampled per forward.
            dynamic_img_pad: Forwarded to the patch embedding layer.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            patch_drop_rate: Patch dropout rate (PatchDropout enabled when > 0).
            proj_drop_rate: Projection dropout rate inside blocks.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate (linearly increased over depth).
            embed_layer: Patch embedding layer.
            embed_norm_layer: Normalization layer to use / override in patch embed module.
            norm_layer: Normalization layer (default: LayerNorm with eps=1e-6).
            act_layer: MLP activation layer (default: GELU).
            block_fn: Transformer block layer.
            mlp_layer: MLP layer used inside each block.
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
        assert class_token or global_pool != 'token'
        assert pos_embed in ('', 'none', 'learn', 'rope')
        assert attn_mode in ('mha', 'mqa', 'mla')
        rope_kwargs = {} if rope_kwargs is None else dict(rope_kwargs)
        rope_kwargs.setdefault("dtype", torch.float32)  # robust with mixed-precision
        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
        self.dynamic_img_size = dynamic_img_size

        embed_args = {}
        if self.dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt="NHWDC"))
        elif pos_embed == 'rope':
            # RoPE needs the (H, W, D) patch grid, so keep the output unflattened.
            embed_args['output_fmt'] = "NHWDC"
        if embed_norm_layer is not None:
            embed_args['norm_layer'] = embed_norm_layer
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches
        reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed, self.rope, self.requires_per_sample_rope = None, None, False
        if pos_embed == 'learn':
            self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        if pos_embed == 'rope':
            self.rope = RotaryPositionEmbedding(
                embed_dim=embed_dim,
                num_heads=num_heads,
                **rope_kwargs,
            )
            # When any coordinate augmentation is configured, a separate RoPE is
            # computed per sample in _pos_embed (presumably because the augmentation
            # is random per call — confirm in RotaryPositionEmbedding).
            self.requires_per_sample_rope = any([
                self.rope.shift_coords is not None,
                self.rope.jitter_coords is not None,
                self.rope.rescale_coords is not None,
            ])
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                attn_mode=attn_mode,
                q_proj_dim=q_proj_dim,
                kv_proj_dim=kv_proj_dim,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                proj_bias=proj_bias,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
            )
            for i in range(depth)])
        self.feature_info = [
            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
        self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == 'map':
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                act_layer=act_layer,
            )
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights()

    def init_weights(self) -> None:
        """Initialize prefix tokens, pos_embed and all Linear layers (meta tensors are skipped)."""
        if self.pos_embed is not None and not self.pos_embed.is_meta:
            nn.init.trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None and not self.cls_token.is_meta:
            nn.init.normal_(self.cls_token, std=1e-6)
        if self.reg_token is not None and not self.reg_token.is_meta:
            nn.init.normal_(self.reg_token, std=1e-6)
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Truncated-normal weights (std .02) and zero bias for every nn.Linear."""
        # this fn left here for compat with downstream users
        if isinstance(m, nn.Linear):
            if not m.weight.is_meta:
                nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None and not m.bias.is_meta:
                nn.init.zeros_(m.bias)

    def no_weight_decay(self) -> Set:
        """Parameter names to exclude from weight decay.

        'dist_token' is listed for compatibility only; this model defines no such parameter.
        NOTE(review): 'reg_token' is not listed — confirm whether that is intended.
        """
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self) -> nn.Module:
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the classification head, optionally changing the pooling type.

        Switching to 'map' pooling is not supported if the model was built without it;
        switching away from 'map' drops the attention-pooling module.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
            if global_pool == 'map' and self.attn_pool is None:
                assert False, "Cannot currently add attention pooling in reset_classifier()."
            elif global_pool != 'map' and self.attn_pool is not None:
                self.attn_pool = None  # remove attention pooling
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def set_input_size(
        self,
        img_size: Optional[Tuple[int, int, int]] = None,
        patch_size: Optional[Tuple[int, int, int]] = None,
    ):
        """Method updates the input image resolution, patch size

        A learnable pos_embed is resampled to the new patch grid when its token count
        changes.

        Args:
            img_size: New input resolution, if None current resolution is used
            patch_size: New patch size, if None existing patch size is used
        """
        prev_grid_size = self.patch_embed.grid_size
        self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
        if self.pos_embed is not None:
            num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
            num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
            if num_new_tokens != self.pos_embed.shape[1]:
                self.pos_embed = nn.Parameter(resample_abs_pos_embed(
                    self.pos_embed,
                    new_size=self.patch_embed.grid_size,
                    old_size=prev_grid_size,
                    num_prefix_tokens=num_prefix_tokens,
                    verbose=True,
                ))

    def _pos_embed(self, x: torch.Tensor):
        """Flatten patch tokens, prepend prefix tokens and apply position embeddings.

        ``x`` is the patch-embed output: a 5-D (B, H, W, D, C) grid when dynamic_img_size
        or RoPE is enabled (output_fmt "NHWDC"), otherwise presumably already (B, N, C).

        Returns:
            (tokens, rope) where tokens is (B, num_prefix_tokens + N, C) and rope is None,
            the RotaryPositionEmbedding output, or a per-sample list of such outputs.
        """
        if self.pos_embed is None and self.rope is None:
            x = x.view(x.shape[0], -1, x.shape[-1])
            if self.reg_token is not None:
                x = torch.cat([self.reg_token.expand(x.shape[0], -1, -1), x], dim=1)
            if self.cls_token is not None:
                x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
            return x, None

        # H, W, D are only bound (and only needed) on the dynamic-size / RoPE paths.
        if self.dynamic_img_size or self.rope is not None:
            B, H, W, D, C = x.shape
            x = x.view(B, -1, C)

        pos_embed, rope = None, None
        if self.pos_embed is not None:
            if self.dynamic_img_size:
                prev_grid_size = self.patch_embed.grid_size
                pos_embed = resample_abs_pos_embed(
                    self.pos_embed,
                    new_size=(H, W, D),
                    old_size=prev_grid_size,
                    num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
                )
            else:
                pos_embed = self.pos_embed

        if self.rope is not None:
            if self.requires_per_sample_rope:
                rope = [self.rope(H=H, W=W, D=D) for _ in range(B)]
            else:
                rope = self.rope(H=H, W=W, D=D)

        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            if pos_embed is not None:
                x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            if pos_embed is not None:
                x = x + pos_embed

        return self.pos_drop(x), rope

    def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        return_prefix_tokens: bool = False,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHWD',
        intermediates_only: bool = False,
        output_dict: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor, unpacked as (B, C, H, W, D)
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            return_prefix_tokens: Return both prefix and spatial intermediate tokens
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs ('NCHWD' or 'NLC')
            intermediates_only: Only return intermediate features
            output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys

        Returns:
            A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
            'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
        """
        assert output_fmt in ('NCHWD', 'NLC'), 'Output format must be one of NCHWD or NLC.'
        reshape = output_fmt == 'NCHWD'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)

        # forward pass
        B, _, height, width, depth = x.shape
        x = self.patch_embed(x)
        x, rope = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.blocks
        else:
            blocks = self.blocks[:max_index + 1]
        for i, blk in enumerate(blocks):
            x = blk(x, rope=rope)
            if i in take_indices:
                # normalize intermediates with final norm layer if enabled
                intermediates.append(self.norm(x) if norm else x)

        # process intermediates
        if self.num_prefix_tokens:
            # split prefix (e.g. class, distill) and spatial feature tokens
            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
        else:
            prefix_tokens = None

        if reshape:
            # reshape to (B, C, H, W, D) output format
            H, W, D = self.patch_embed.dynamic_feat_size((height, width, depth))
            intermediates = [y.reshape(B, H, W, D, -1).permute(0, 4, 1, 2, 3).contiguous() for y in intermediates]

        if output_dict:
            result_dict = {}
            # Intermediates are always included
            result_dict['image_intermediates'] = intermediates
            if prefix_tokens is not None and return_prefix_tokens:
                result_dict['image_intermediates_prefix'] = prefix_tokens
            # Only include features if not intermediates_only
            if not intermediates_only:
                x_final = self.norm(x)
                result_dict['image_features'] = x_final
            return result_dict

        # For non-dictionary output, maintain the original behavior
        if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
            # return_prefix not support in torchscript due to poor type handling
            intermediates = list(zip(intermediates, prefix_tokens))
        if intermediates_only:
            return intermediates
        x = self.norm(x)
        return x, intermediates

    def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
    ):
        """Prune layers not required for specified intermediates.

        Args:
            indices: Indices of intermediate layers to keep.
            prune_norm: Whether to prune normalization layer.
            prune_head: Whether to prune the classifier head.

        Returns:
            List of indices that were kept.
        """
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.fc_norm = nn.Identity()
            self.reset_classifier(0, '')
        return take_indices

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, List[int], Tuple[int]] = 1,
        reshape: bool = False,
        return_prefix_tokens: bool = False,
        norm: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Get intermediate layer outputs (DINO interface compatibility).

        NOTE: This API is for backwards compat, favour using forward_intermediates() directly.

        Args:
            x: Input tensor.
            n: Number or indices of layers.
            reshape: Reshape to NCHWD format.
            return_prefix_tokens: Return prefix tokens.
            norm: Apply normalization.

        Returns:
            List of intermediate features.
        """
        return self.forward_intermediates(
            x, n,
            return_prefix_tokens=return_prefix_tokens,
            norm=norm,
            output_fmt='NCHWD' if reshape else 'NLC',
            intermediates_only=True,
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm)."""
        x = self.patch_embed(x)
        x, rope = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for blk in self.blocks:
            x = blk(x, rope=rope)
        x = self.norm(x)
        return x

    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
        """Apply pooling to feature tokens.

        Attention pooling, when configured, takes precedence over ``pool_type``.

        Args:
            x: Feature tensor.
            pool_type: Pooling type override.

        Returns:
            Pooled features.
        """
        if self.attn_pool is not None:
            x = self.attn_pool(x)
            return x
        pool_type = self.global_pool if pool_type is None else pool_type
        x = global_pool_nlc(
            x,
            pool_type=pool_type,
            num_prefix_tokens=self.num_prefix_tokens,
        )
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Forward pass through classifier head.

        Args:
            x: Feature tensor.
            pre_logits: Return features before final classifier.

        Returns:
            Output tensor.
        """
        x = self.pool(x)
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full forward: features, then pooling and classifier head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_path_or_url: Union[str, os.PathLike],
        verbose: bool = True,
        **kwargs
    ) -> 'VisionTransformer':
        """Build ``cls(**kwargs)`` and load pretrained weights from a local path or a URL.

        Args:
            checkpoint_path_or_url: One of
                - a Hugging Face Hub file URL such as
                  ``https://huggingface.co/<org>/<repo>/resolve/<revision>/<path/to/file>``
                  (downloaded with ``hf_hub_download``; revision and subfolders are kept),
                - any other http(s) URL (downloaded with ``torch.hub``),
                - a local filesystem path.
            verbose: Print progress and the ``load_state_dict`` result.
            **kwargs: Constructor arguments forwarded to ``cls``.

        Returns:
            The model, with weights loaded via ``load_state_dict(strict=False)``.

        Raises:
            FileNotFoundError: If a local checkpoint path does not exist.

        NOTE(security): the URL and local-path branches unpickle with
        ``weights_only=False``, which can execute arbitrary code embedded in the file.
        Only load checkpoints from trusted sources.
        """
        from urllib.parse import unquote

        model = cls(**kwargs)

        def _is_url(path: str) -> bool:
            try:
                parsed = urlparse(str(path))
                return parsed.scheme in ('http', 'https')
            except Exception:
                return False

        def _is_hf_url(path: str) -> bool:
            try:
                parsed = urlparse(str(path))
                return 'huggingface.co' in parsed.netloc
            except Exception:
                return False

        def _split_hf_url(url: str) -> Tuple[str, str, Optional[str]]:
            # Hub file URLs look like /<org>/<repo>/(resolve|blob)/<revision>/<path...>.
            parts = urlparse(str(url)).path.strip('/').split('/')
            repo_id = '/'.join(parts[:2])  # e.g., 'cclaess/SPECTRE'
            if len(parts) >= 5 and parts[2] in ('resolve', 'blob'):
                revision = unquote(parts[3])
                filename = unquote('/'.join(parts[4:]))  # keeps any subfolder
                return repo_id, filename, revision
            # Fallback: last path component on the default revision.
            return repo_id, unquote(parts[-1]), None

        if _is_hf_url(checkpoint_path_or_url):
            if verbose:
                print(f"Downloading pretrained weights from Hugging Face URL: {checkpoint_path_or_url}")
            repo_id, filename, revision = _split_hf_url(checkpoint_path_or_url)
            local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
            state_dict = load_state_dict_from_file(local_path, map_location='cpu')
        elif _is_url(checkpoint_path_or_url):
            if verbose:
                print(f"Downloading pretrained weights from URL: {checkpoint_path_or_url}")
            # weights_only=False unpickles arbitrary objects: trusted URLs only.
            state_dict = torch.hub.load_state_dict_from_url(
                checkpoint_path_or_url, map_location='cpu', weights_only=False, progress=verbose)
        else:
            local_path = os.fspath(checkpoint_path_or_url)
            if not os.path.exists(local_path):
                raise FileNotFoundError(f"Checkpoint file not found: {local_path}")
            if verbose:
                print(f"Loading checkpoint from local path: {local_path}")
            # weights_only=False unpickles arbitrary objects: trusted files only.
            state_dict = torch.load(local_path, map_location='cpu', weights_only=False)

        msg = model.load_state_dict(state_dict, strict=False)
        if verbose:
            print(f"Loaded pretrained weights with msg: {msg}")
        return model
def vit_tiny_patch16_128(
    checkpoint_path_or_url: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Tiny model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].

    Args:
        checkpoint_path_or_url: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(128, 128, 64),
        patch_size=(16, 16, 8),
        embed_dim=192,
        depth=12,
        num_heads=2,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if checkpoint_path_or_url is not None:
        return VisionTransformer.from_pretrained(checkpoint_path_or_url, **model_args)
    return VisionTransformer(**model_args)
def vit_small_patch16_128(
    checkpoint_path_or_url: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Small model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].

    Args:
        checkpoint_path_or_url: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(128, 128, 64),
        patch_size=(16, 16, 8),
        embed_dim=384,
        depth=12,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if checkpoint_path_or_url is not None:
        return VisionTransformer.from_pretrained(checkpoint_path_or_url, **model_args)
    return VisionTransformer(**model_args)
def vit_base_patch16_128(
    checkpoint_path_or_url: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Base model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].

    Args:
        checkpoint_path_or_url: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(128, 128, 64),
        patch_size=(16, 16, 8),
        embed_dim=768,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if checkpoint_path_or_url is not None:
        return VisionTransformer.from_pretrained(checkpoint_path_or_url, **model_args)
    return VisionTransformer(**model_args)
def vit_base_patch16_256(
    pretrained_weights: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Base model with 3D patch embedding, patch size [16, 16, 8] and input size [256, 256, 128].

    Args:
        pretrained_weights: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(256, 256, 128),
        patch_size=(16, 16, 8),
        embed_dim=768,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if pretrained_weights is not None:
        return VisionTransformer.from_pretrained(pretrained_weights, **model_args)
    return VisionTransformer(**model_args)
def vit_base_patch32_128(
    checkpoint_path_or_url: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Base model with 3D patch embedding, patch size [32, 32, 16] and input size [128, 128, 64].

    Args:
        checkpoint_path_or_url: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(128, 128, 64),
        patch_size=(32, 32, 16),
        embed_dim=768,
        depth=12,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if checkpoint_path_or_url is not None:
        return VisionTransformer.from_pretrained(checkpoint_path_or_url, **model_args)
    return VisionTransformer(**model_args)
def vit_large_patch16_128(
    checkpoint_path_or_url: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [128, 128, 64].

    Args:
        checkpoint_path_or_url: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(128, 128, 64),
        patch_size=(16, 16, 8),
        embed_dim=1080,
        depth=24,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if checkpoint_path_or_url is not None:
        return VisionTransformer.from_pretrained(checkpoint_path_or_url, **model_args)
    return VisionTransformer(**model_args)
def vit_large_patch16_256(
    pretrained_weights: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [256, 256, 128].

    Args:
        pretrained_weights: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(256, 256, 128),
        patch_size=(16, 16, 8),
        embed_dim=1080,
        depth=24,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if pretrained_weights is not None:
        return VisionTransformer.from_pretrained(pretrained_weights, **model_args)
    return VisionTransformer(**model_args)
def vit_large_patch16_320(
    pretrained_weights: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Large model with 3D patch embedding, patch size [16, 16, 8] and input size [320, 320, 128].

    Args:
        pretrained_weights: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(320, 320, 128),
        patch_size=(16, 16, 8),
        embed_dim=1080,
        depth=24,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if pretrained_weights is not None:
        return VisionTransformer.from_pretrained(pretrained_weights, **model_args)
    return VisionTransformer(**model_args)
def vit_large_patch32_128(
    checkpoint_path_or_url: Optional[str] = None,
    **kwargs
) -> VisionTransformer:
    """ViT-Large model with 3D patch embedding, patch size [32, 32, 16] and input size [128, 128, 64].

    Args:
        checkpoint_path_or_url: Optional checkpoint location passed to
            ``VisionTransformer.from_pretrained``; if None, weights are randomly initialized.
        **kwargs: Extra ``VisionTransformer`` arguments; keys that also appear in the
            architecture defaults below override those defaults.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model_args = dict(
        img_size=(128, 128, 64),
        patch_size=(32, 32, 16),
        embed_dim=1080,
        depth=24,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=nn.LayerNorm,
    )
    # Caller kwargs win over defaults (previously a duplicate key raised TypeError).
    model_args = dict(model_args, **kwargs)
    if checkpoint_path_or_url is not None:
        return VisionTransformer.from_pretrained(checkpoint_path_or_url, **model_args)
    return VisionTransformer(**model_args)