| import albumentations as A |
| import cv2 |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from numpy.typing import NDArray |
| from transformers import PreTrainedModel |
| from timm import create_model |
| from typing import Optional |
| from .configuration import BoneAgeConfig |
|
|
|
|
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling with a learnable exponent.

    The output is ``avg_pool(clamp(x, min=eps) ** p, output_size=1) ** (1 / p)``,
    i.e. every spatial position is collapsed into a single cell. ``p`` is a
    trainable ``nn.Parameter`` of shape (1,) initialised to the given value.

    Args:
        p: Initial value of the pooling exponent.
        eps: Lower bound applied to the input before exponentiation, keeping
            the power well-defined for non-positive activations.
        dim: 2 to pool with ``F.adaptive_avg_pool2d``, 3 to pool with
            ``F.adaptive_avg_pool3d``. Any other value fails an assertion.
        flatten: When True, the pooled result is flattened from dim 1 onward
            (``nn.Flatten(1)``); otherwise it is returned unchanged.
    """

    def __init__(
        self, p: int = 3, eps: float = 1e-6, dim: int = 2, flatten: bool = True
    ):
        super().__init__()
        pool_by_dim = {2: F.adaptive_avg_pool2d, 3: F.adaptive_avg_pool3d}
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
        assert dim in pool_by_dim, f"dim must be one of [2, 3], not {dim}"
        self.dim = dim
        self.func = pool_by_dim[dim]
        self.flatten = nn.Identity() if not flatten else nn.Flatten(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply GeM pooling to ``x`` and optionally flatten the result."""
        # Raise to p, average down to a single cell, then take the p-th root.
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = self.func(powered, output_size=1)
        return self.flatten(pooled.pow(1.0 / self.p))
|
|
|
|
class BoneAgeModel(nn.Module):
    """Single timm backbone + GeM pooling + linear head over ``num_classes`` bins.

    The input image is concatenated with a constant "female" channel (255 for
    samples flagged female, 0 otherwise). The result is scaled from [0, 255] to
    [-1, 1] and passed through the backbone. The prediction is the expected
    class index under the softmax of the logits.

    Args:
        backbone: timm model name passed to ``create_model``.
        feature_dim: Number of channels produced by the backbone. GeM pools only
            the spatial dims, so this must match the backbone's output channels.
        dropout: Probability for ``self.dropout``.
        num_classes: Number of output logits / class bins.
        in_chans: Channels the backbone expects. ``forward`` doubles the
            channel count of ``x`` (image + same-shaped female channel), so ``x``
            must carry ``in_chans // 2`` channels.
    """

    def __init__(
        self, backbone, feature_dim=768, dropout=0.1, num_classes=240, in_chans=2
    ):
        super().__init__()
        # pretrained=False: weights presumably come from a saved checkpoint
        # rather than timm's hub — NOTE(review): confirm with the loading code.
        self.backbone = create_model(
            model_name=backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=False,
            in_chans=in_chans,
        )
        self.pooling = GeM(p=3, dim=2)
        # NOTE(review): self.dropout is constructed but never applied in
        # forward(); confirm whether it was meant to wrap `features` before
        # self.linear (this only matters in train mode).
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(feature_dim, num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Linearly map values from [0, 255] to [-1, 1]."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    def forward(
        self, x: torch.Tensor, female: torch.Tensor, return_logits: bool = False
    ) -> torch.Tensor:
        """Predict from an image batch and per-sample female flags.

        Args:
            x: Image batch in the [0, 255] range. The female channel is built
                with ``torch.zeros_like(x)`` and concatenated on dim 1.
            female: Per-sample flags. Truthy entries (after ``.bool()``) mark
                the samples whose female channel is set to 255. Its first
                dimension must equal ``x.size(0)``.
            return_logits: If True, return the raw logits of shape
                (batch, num_classes) instead of the expected value.

        Returns:
            Logits, or the softmax-weighted mean class index of shape (batch,).
            NOTE(review): the class index presumably encodes age in months.
        """
        assert x.size(0) == female.size(0), (
            f"x.size(0) [{x.size(0)}] must equal female.size(0) [{female.size(0)}]"
        )
        # zeros_like already matches x's dtype and device.
        female_ch = torch.zeros_like(x)
        # Align the mask with x's device so the masked write never mixes devices.
        female_ch[female.to(device=x.device).bool()] = 255.0
        x = torch.cat([x, female_ch], dim=1)
        x = self.normalize(x)
        features = self.pooling(self.backbone(x))
        logits = self.linear(features)
        if return_logits:
            return logits
        # Create the class-index vector on the target device directly instead of
        # allocating on CPU and copying host->device on every call.
        class_idx = torch.arange(logits.size(1), device=logits.device)
        return (logits.softmax(1) * class_idx).sum(1)
|
|
|
|
class BoneAgeEnsembleModel(PreTrainedModel):
    """Ensemble of ``config.num_models`` BoneAgeModel nets whose outputs are averaged.

    Members are registered as individual attributes ``net0``, ``net1``, ...
    so their parameters live under those state-dict prefixes.
    NOTE(review): checkpoints presumably rely on these key names, so do not
    switch to an ``nn.ModuleList`` without remapping weights.
    """

    config_class = BoneAgeConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_models = config.num_models
        for i in range(self.num_models):
            setattr(
                self,
                f"net{i}",
                BoneAgeModel(
                    config.backbone,
                    config.feature_dim,
                    config.dropout,
                    config.num_classes,
                    config.in_chans,
                ),
            )

    @staticmethod
    def load_image_from_dicom(path: str) -> Optional[NDArray]:
        """Read a DICOM file and return its pixels min-max scaled to uint8.

        The file's VOI LUT / windowing is applied via ``apply_voi_lut``.
        MONOCHROME1 images are inverted (``max - arr``). The result is shifted
        to start at 0 and scaled so its maximum becomes 255. A constant image
        yields all zeros.

        Args:
            path: Path to the DICOM file.

        Returns:
            A uint8 array, or None when ``pydicom`` is not installed.
        """
        try:
            from pydicom import dcmread
        except ModuleNotFoundError:
            print("`pydicom` is not installed, returning None ...")
            return None
        try:
            # pydicom >= 3.0
            from pydicom.pixels import apply_voi_lut
        except ImportError:
            # pydicom 2.x: same helper, older location.
            from pydicom.pixel_data_handlers.util import apply_voi_lut
        dicom = dcmread(path)
        # Work in float64 so the inversion and min subtraction cannot wrap or
        # overflow integer pixel data (astype also copies, so pixel_array is
        # never mutated).
        arr = apply_voi_lut(dicom.pixel_array, dicom).astype("float64")
        if getattr(dicom, "PhotometricInterpretation", None) == "MONOCHROME1":
            arr = arr.max() - arr
        arr = arr - arr.min()
        peak = arr.max()
        # Guard against 0/0 -> NaN for constant images.
        if peak > 0:
            arr = arr / peak
        return (arr * 255).astype("uint8")

    @staticmethod
    def preprocess(x: NDArray) -> NDArray:
        """Rescale so the longest side is 512 px, then pad to 512x512.

        Padding uses ``cv2.BORDER_CONSTANT`` with albumentations' default fill.
        """
        x = A.LongestMaxSize(max_size=512, p=1)(image=x)["image"]
        x = A.PadIfNeeded(512, 512, border_mode=cv2.BORDER_CONSTANT, p=1)(image=x)[
            "image"
        ]
        return x

    def forward(
        self, x: torch.Tensor, female: torch.Tensor, return_logits: bool = False
    ) -> torch.Tensor:
        """Run every member net on the same inputs and average their outputs.

        With ``return_logits=True`` the members' logits are averaged. Otherwise
        each member's expected-value prediction is averaged.
        """
        outputs = [
            getattr(self, f"net{i}")(x, female, return_logits)
            for i in range(self.num_models)
        ]
        return torch.stack(outputs).mean(0)
|
|