Feature Extraction
Transformers
Safetensors
English
spectre
medical-imaging
ct-scan
3d
vision-transformer
self-supervised-learning
foundation-model
radiology
custom_code
Instructions to use cclaess/SPECTRE-Large with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use cclaess/SPECTRE-Large with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="cclaess/SPECTRE-Large", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("cclaess/SPECTRE-Large", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| from typing import Type, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.jit import Final | |
| from timm.layers import use_fused_attn | |
| from spectre.models.layers.rotary_pos_embed import rope_apply | |
| class Attention(nn.Module): | |
| fused_attn: Final[bool] | |
| def __init__( | |
| self, | |
| dim: int, | |
| num_heads: int = 8, | |
| mode: str = "mha", | |
| q_proj_dim: Optional[int] = None, | |
| kv_proj_dim: Optional[int] = None, | |
| qkv_bias: bool = False, | |
| qk_norm: bool = False, | |
| proj_bias: bool = True, | |
| attn_drop: float = 0., | |
| proj_drop: float = 0., | |
| norm_layer: Type[nn.Module] = nn.LayerNorm, | |
| ) -> None: | |
| super().__init__() | |
| assert dim % num_heads == 0, 'dim should be divisible by num_heads' | |
| self.num_heads = num_heads | |
| self.head_dim = dim // num_heads | |
| self.scale = self.head_dim ** -0.5 | |
| self.fused_attn = use_fused_attn() | |
| self.mode = mode.lower() | |
| assert self.mode in ["mha", "mqa", "mla"], "Attention mode must be 'mha', 'mqa', or 'mla'" | |
| assert not (self.mode == "mla" and kv_proj_dim is None), "kv_proj_dim must be provided for 'mla' mode" | |
| assert not (self.mode == "mla" and q_proj_dim is None), "q_proj_dim must be provided for 'mla' mode" | |
| if self.mode == "mha": | |
| self.q = nn.Linear(dim, dim, bias=qkv_bias) | |
| self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias) # Key and value pair for every head | |
| self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | |
| self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | |
| elif self.mode == "mqa": | |
| self.q = nn.Linear(dim, dim, bias=qkv_bias) | |
| self.kv = nn.Linear(dim, 2 * self.head_dim, bias=qkv_bias) # Key and value pair shared across heads | |
| self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | |
| self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() | |
| elif self.mode == "mla": | |
| self.q_proj = nn.Linear(dim, q_proj_dim, bias=qkv_bias) # Projected query for every head | |
| self.kv_proj = nn.Linear(dim, kv_proj_dim, bias=qkv_bias) # Projected key and value pair for every head | |
| self.q_norm = norm_layer(q_proj_dim) if qk_norm else nn.Identity() | |
| self.kv_norm = norm_layer(kv_proj_dim) if qk_norm else nn.Identity() | |
| self.q = nn.Linear(q_proj_dim, dim, bias=qkv_bias) # Query for every head | |
| self.kv = nn.Linear(kv_proj_dim, 2 * dim, bias=qkv_bias) # Key and value pair for every head | |
| self.attn_drop = nn.Dropout(attn_drop) | |
| self.proj = nn.Linear(dim, dim, bias=proj_bias) | |
| self.proj_drop = nn.Dropout(proj_drop) | |
| def apply_rotary_pos_emb( | |
| self, | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| rope: Tuple[torch.Tensor, torch.Tensor], | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Apply RoPE to the query and key tensors. | |
| Args: | |
| q (torch.Tensor): Query tensor of shape (B, num_heads, N, head_dim) | |
| k (torch.Tensor): Key tensor of shape (B, num_heads, N, head_dim) | |
| rope (Tuple[torch.Tensor, torch.Tensor]): Tuple of (sin, cos) tensors for RoPE application. | |
| Sin and cos can be of shape | |
| """ | |
| # Match dtype to rope for numeric stability | |
| q_dtype, k_dtype = q.dtype, k.dtype | |
| sin, cos = rope | |
| if sin.ndim == 2: | |
| sin = sin.unsqueeze(0).unsqueeze(0) | |
| cos = cos.unsqueeze(0).unsqueeze(0) | |
| elif sin.ndim == 3: | |
| sin = sin.unsqueeze(1) | |
| cos = cos.unsqueeze(1) | |
| else: | |
| raise ValueError("RoPE sin/cos must be of shape [N, head_dim] or [B, N, head_dim]") | |
| rope_dtype = sin.dtype | |
| q = q.to(dtype=rope_dtype) | |
| k = k.to(dtype=rope_dtype) | |
| N = q.shape[-2] # total tokens per sample | |
| N_spatial = sin.shape[-2] # number of spatial tokens covered by rope | |
| prefix = N - N_spatial # e.g., [cls] or [reg] tokens at the front | |
| assert prefix >= 0, "RoPE sin/cos length exceeds sequence length" | |
| if prefix > 0: | |
| q_prefix = q[:, :, :prefix, :] | |
| k_prefix = k[:, :, :prefix, :] | |
| q_spatial = q[:, :, prefix:, :] | |
| k_spatial = k[:, :, prefix:, :] | |
| else: | |
| q_prefix = k_prefix = None | |
| q_spatial, k_spatial = q, k | |
| # Apply RoPE on the spatial tail | |
| q_spatial = rope_apply(q_spatial, sin, cos) | |
| k_spatial = rope_apply(k_spatial, sin, cos) | |
| # Stitch back | |
| if prefix > 0: | |
| q = torch.cat((q_prefix, q_spatial), dim=-2) | |
| k = torch.cat((k_prefix, k_spatial), dim=-2) | |
| else: | |
| q, k = q_spatial, k_spatial | |
| # Cast back to original dtypes | |
| q = q.to(dtype=q_dtype) | |
| k = k.to(dtype=k_dtype) | |
| return q, k | |
| def compute_qkv(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| B, N, _ = x.shape | |
| if self.mode == "mha": | |
| q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) | |
| kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | |
| k, v = kv.unbind(0) | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| elif self.mode == "mqa": | |
| q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) | |
| kv = self.kv(x).reshape(B, N, 2, 1, self.head_dim).permute(2, 0, 3, 1, 4) | |
| kv = kv.expand(-1, -1, self.num_heads, -1, -1) # Expand to match num_heads | |
| k, v = kv.unbind(0) | |
| q, k = self.q_norm(q), self.k_norm(k) | |
| elif self.mode == "mla": | |
| q = self.q_proj(x) | |
| kv = self.kv_proj(x) | |
| q, kv = self.q_norm(q), self.kv_norm(kv) # Normalization on projections | |
| q = self.q(q).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) | |
| kv = self.kv(kv).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) | |
| k, v = kv.unbind(0) | |
| return q, k, v | |
| def compute_attention( | |
| self, | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| v: torch.Tensor, | |
| ) -> torch.Tensor: | |
| B, _, N, _ = q.shape | |
| C = self.num_heads * self.head_dim | |
| if self.fused_attn: | |
| x = F.scaled_dot_product_attention( | |
| q, k, v, | |
| dropout_p=self.attn_drop.p if self.training else 0., | |
| ) | |
| else: | |
| q = q * self.scale | |
| attn = q @ k.transpose(-2, -1) | |
| attn = attn.softmax(dim=-1) | |
| attn = self.attn_drop(attn) | |
| x = attn @ v | |
| return x.transpose(1, 2).reshape(B, N, C) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| rope = None, | |
| ) -> torch.Tensor: | |
| q, k, v = self.compute_qkv(x) | |
| if rope is not None: | |
| if isinstance(rope, list): | |
| rope = tuple(torch.stack([r[i] for r in rope], dim=0) for i in range(2)) | |
| q, k = self.apply_rotary_pos_emb(q, k, rope) | |
| x = self.compute_attention(q, k, v) | |
| x = self.proj(x) | |
| x = self.proj_drop(x) | |
| return x | |