"""
ConvNeXt Classifier Tool - skin lesion classification using ConvNeXt + MONET features.

Loads the seed42_fold0.pt checkpoint and performs classification.
"""

import os
import warnings
from typing import Optional, Dict, List, Tuple

import numpy as np
import timm
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms


# Diagnostic labels in model-output order: entry i names probability/logit i
# wherever this module indexes CLASS_NAMES with a class index.
CLASS_NAMES = [
    'AKIEC',
    'BCC',
    'BEN_OTH',
    'BKL',
    'DF',
    'INF',
    'MAL_OTH',
    'MEL',
    'NV',
    'SCCKA',
    'VASC',
]

# Human-readable description for every label, listed in the same order as
# CLASS_NAMES and paired with it positionally.
CLASS_FULL_NAMES = dict(zip(CLASS_NAMES, (
    'Actinic Keratosis / Intraepithelial Carcinoma',
    'Basal Cell Carcinoma',
    'Benign Other',
    'Benign Keratosis-like Lesion',
    'Dermatofibroma',
    'Inflammatory',
    'Malignant Other',
    'Melanoma',
    'Melanocytic Nevus',
    'Squamous Cell Carcinoma / Keratoacanthoma',
    'Vascular Lesion',
)))


class ConvNeXtDualEncoder(nn.Module):
    """
    Dual-image ConvNeXt model matching the trained checkpoint.
    Processes BOTH clinical and dermoscopy images through one shared backbone,
    embeds the metadata vector with a small MLP, concatenates the three
    feature vectors and classifies the result.

    Metadata input: 19 dimensions
        - age (1): normalized age
        - sex (4): one-hot encoded
        - site (7): one-hot encoded (reduced from 14)
        - MONET (7): 7 MONET feature scores
    """

    # Width of the metadata embedding produced by ``meta_mlp``. It also sizes
    # the zero placeholder used when no metadata is supplied, so both places
    # stay in sync.
    META_EMBED_DIM = 64

    def __init__(
        self,
        model_name: str = 'convnext_base.fb_in22k_ft_in1k',
        metadata_dim: int = 19,
        num_classes: int = 11,
        dropout: float = 0.3
    ):
        """
        Build the backbone, the metadata MLP and the fusion classifier.

        Args:
            model_name: timm model identifier for the image backbone
            metadata_dim: length of the metadata vector fed to ``meta_mlp``
            num_classes: number of output logits
            dropout: dropout probability used in the MLP and classifier head
        """
        super().__init__()

        # num_classes=0 requests the backbone without its own classification
        # head; pretrained=False because no pretrained weights are fetched here.
        self.backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=0
        )
        backbone_dim = self.backbone.num_features

        meta_dim = self.META_EMBED_DIM
        self.meta_mlp = nn.Sequential(
            nn.Linear(metadata_dim, meta_dim),
            nn.LayerNorm(meta_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Fused vector = clinical features + dermoscopy features + metadata embedding.
        fusion_dim = backbone_dim * 2 + meta_dim
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes)
        )

        self.metadata_dim = metadata_dim
        self.num_classes = num_classes
        self.backbone_dim = backbone_dim

    def forward(
        self,
        clinical_img: torch.Tensor,
        derm_img: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass with dual images.

        Args:
            clinical_img: [B, 3, H, W] clinical image tensor
            derm_img: [B, 3, H, W] dermoscopy image tensor; when None the
                clinical features are reused in its place
            metadata: [B, 19] metadata tensor; when None the metadata MLP is
                skipped and an all-zero embedding is used instead

        Returns:
            logits: [B, 11]
        """
        clinical_features = self.backbone(clinical_img)

        # Without a dermoscopy image, reuse the clinical features rather than
        # running the backbone a second time on the same input.
        if derm_img is None:
            derm_features = clinical_features
        else:
            derm_features = self.backbone(derm_img)

        if metadata is None:
            # Allocate the placeholder in the feature dtype as well as on the
            # feature device: a float32 placeholder would promote the fused
            # tensor and break a classifier running in reduced precision.
            meta_features = torch.zeros(
                clinical_features.size(0), self.META_EMBED_DIM,
                device=clinical_features.device,
                dtype=clinical_features.dtype
            )
        else:
            meta_features = self.meta_mlp(metadata)

        fused = torch.cat([clinical_features, derm_features, meta_features], dim=1)
        return self.classifier(fused)


class ConvNeXtClassifier:
    """
    ConvNeXt classifier tool for skin lesion classification.
    Uses dual images (clinical + dermoscopy) and MONET features.

    Wraps a ConvNeXtDualEncoder and takes care of device selection, checkpoint
    loading, image preprocessing, metadata encoding and turning the logits
    into ranked predictions. The model is loaded on the first classify() call
    unless load() was called explicitly before.
    """

    # Substring keyword -> index into the 7-way site one-hot.
    # encode_metadata() scans these in insertion order and stops at the first
    # keyword contained in the site string, so the order here is significant.
    SITE_MAPPING = {
        'head': 0, 'neck': 0, 'face': 0,
        'trunk': 1, 'back': 1, 'chest': 1, 'abdomen': 1,
        'upper': 2, 'arm': 2, 'hand': 2,
        'lower': 3, 'leg': 3, 'foot': 3, 'thigh': 3,
        'genital': 4, 'oral': 5, 'acral': 6,
    }

    # Sex label -> index into the 4-way sex one-hot; labels not listed here
    # are encoded as 'unknown'.
    SEX_MAPPING = {'male': 0, 'female': 1, 'other': 2, 'unknown': 3}

    def __init__(
        self,
        checkpoint_path: str = "models/seed42_fold0.pt",
        device: Optional[str] = None
    ):
        """
        Store the configuration and build the preprocessing pipeline.

        No weights are read here; see load().

        Args:
            checkpoint_path: path of the checkpoint file passed to torch.load
            device: torch device string, or None to auto-select in load()
        """
        self.checkpoint_path = checkpoint_path
        self.device = device
        self.model = None
        self.loaded = False

        # Fixed 384x384 input with the standard ImageNet mean/std normalization.
        self.transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    @staticmethod
    def _auto_device() -> str:
        """Pick the best available device: cuda, then mps, then cpu."""
        if torch.cuda.is_available():
            return "cuda"
        # Look the backend up defensively so a torch build that does not
        # expose torch.backends.mps falls through to CPU instead of raising.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def load(self):
        """
        Load the ConvNeXt model from checkpoint (no-op if already loaded).

        Device resolution order: the SKINPRO_TOOL_DEVICE environment variable
        (overrides everything when non-empty), the device given to the
        constructor, then auto-detection.

        Accepts either a raw state dict or a dict holding the state dict under
        the 'model_state_dict' key.
        """
        if self.loaded:
            return

        forced = os.environ.get("SKINPRO_TOOL_DEVICE")
        if forced:
            self.device = forced
        elif self.device is None:
            self.device = self._auto_device()

        # Build into a local first: self.model is only published once the
        # weights are in place, so a failed load never leaves an untrained
        # model on the instance.
        model = ConvNeXtDualEncoder(
            model_name='convnext_base.fb_in22k_ft_in1k',
            metadata_dim=19,
            num_classes=11,
            dropout=0.3
        )

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints that come
        # from a trusted source.
        checkpoint = torch.load(
            self.checkpoint_path,
            map_location=self.device,
            weights_only=False
        )

        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

        model.to(self.device)
        model.eval()
        self.model = model
        self.loaded = True

    def encode_metadata(
        self,
        age: Optional[float] = None,
        sex: Optional[str] = None,
        site: Optional[str] = None,
        monet_scores: Optional[List[float]] = None
    ) -> torch.Tensor:
        """
        Encode metadata into 19-dim vector.

        Layout: [age(1), sex(4), site(7), monet(7)] = 19

        Args:
            age: Patient age in years; None encodes as 0.0
            sex: 'male', 'female', 'other', or None. Any other non-empty
                string is encoded as 'unknown'; None/empty leaves the one-hot
                all zero.
            site: Anatomical site string, matched case-insensitively against
                the SITE_MAPPING keywords by substring; no match leaves the
                one-hot all zero.
            monet_scores: List of 7 MONET feature scores. None, or a sequence
                of any other length, is replaced by seven zeros (a warning is
                emitted in the wrong-length case).

        Returns:
            torch.Tensor of shape [19]
        """
        features = []

        # Age is centred on 50 and scaled by 30.
        features.append((age - 50) / 30 if age is not None else 0.0)

        sex_onehot = [0.0] * 4
        if sex:
            sex_onehot[self.SEX_MAPPING.get(sex.lower(), 3)] = 1.0
        features.extend(sex_onehot)

        site_onehot = [0.0] * 7
        if site:
            site_lower = site.lower()
            # First keyword (in SITE_MAPPING order) found in the string wins.
            for key, idx in self.SITE_MAPPING.items():
                if key in site_lower:
                    site_onehot[idx] = 1.0
                    break
        features.extend(site_onehot)

        if monet_scores is not None and len(monet_scores) == 7:
            features.extend(monet_scores)
        else:
            if monet_scores is not None:
                # Keep the best-effort fallback, but do not discard
                # caller-supplied scores silently.
                warnings.warn(
                    f"monet_scores must contain exactly 7 values, got "
                    f"{len(monet_scores)}; using zeros instead",
                    stacklevel=2
                )
            features.extend([0.0] * 7)

        return torch.tensor(features, dtype=torch.float32)

    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to RGB, apply the transform and add a batch dimension."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        return self.transform(image).unsqueeze(0)

    def classify(
        self,
        clinical_image: Image.Image,
        derm_image: Optional[Image.Image] = None,
        age: Optional[float] = None,
        sex: Optional[str] = None,
        site: Optional[str] = None,
        monet_scores: Optional[List[float]] = None,
        top_k: int = 5
    ) -> Dict:
        """
        Classify a skin lesion.

        Args:
            clinical_image: Clinical (close-up) image
            derm_image: Dermoscopy image (optional, uses clinical if None)
            age: Patient age
            sex: Patient sex
            site: Anatomical site
            monet_scores: 7 MONET feature scores
            top_k: Number of top predictions to return. Values are clamped to
                the range [0, number of classes]; None returns all classes.

        Returns:
            dict with:
                'predictions': top_k entries (class, full_name, probability),
                    highest probability first
                'probabilities': probability of every class, in CLASS_NAMES order
                'top_class': label of the most probable class
                'confidence': probability of that class
                'all_classes': CLASS_NAMES
        """
        if not self.loaded:
            self.load()

        clinical_tensor = self.preprocess_image(clinical_image).to(self.device)

        if derm_image is not None:
            derm_tensor = self.preprocess_image(derm_image).to(self.device)
        else:
            derm_tensor = None

        metadata = self.encode_metadata(age, sex, site, monet_scores)
        metadata_tensor = metadata.unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(clinical_tensor, derm_tensor, metadata_tensor)
            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

        # Rank every class once. The winner is taken from the full ranking so
        # that 'top_class' / 'confidence' do not depend on top_k (top_k=0 used
        # to raise IndexError here).
        ranked = np.argsort(probs)[::-1]
        best_idx = int(ranked[0])

        if top_k is None:
            keep = len(ranked)
        else:
            # A negative top_k would otherwise slice from the end and silently
            # drop the least likely classes.
            keep = min(max(top_k, 0), len(ranked))

        predictions = [
            {
                'class': CLASS_NAMES[int(idx)],
                'full_name': CLASS_FULL_NAMES[CLASS_NAMES[int(idx)]],
                'probability': float(probs[idx])
            }
            for idx in ranked[:keep]
        ]

        return {
            'predictions': predictions,
            'probabilities': probs.tolist(),
            'top_class': CLASS_NAMES[best_idx],
            'confidence': float(probs[best_idx]),
            'all_classes': CLASS_NAMES,
        }

    def __call__(
        self,
        clinical_image: Image.Image,
        derm_image: Optional[Image.Image] = None,
        **kwargs
    ) -> Dict:
        """Shorthand for classify()"""
        return self.classify(clinical_image, derm_image, **kwargs)


# Module-wide classifier shared by get_convnext_classifier(); created on first use.
_convnext_instance = None


def get_convnext_classifier(checkpoint_path: str = "models/seed42_fold0.pt") -> ConvNeXtClassifier:
    """
    Return the shared ConvNeXtClassifier, constructing it on the first call.

    Note that checkpoint_path is only used when the instance is created;
    later calls return the existing instance whatever path they pass.
    Constructing the classifier does not load the weights.
    """
    global _convnext_instance
    if _convnext_instance is not None:
        return _convnext_instance
    _convnext_instance = ConvNeXtClassifier(checkpoint_path)
    return _convnext_instance


# Manual smoke test: loads the default checkpoint and, when an image path is
# given as the first command-line argument, prints the top predictions for it.
if __name__ == "__main__":
    import sys

    print("ConvNeXt Classifier Test")
    print("=" * 50)

    classifier = ConvNeXtClassifier()
    print("Loading model...")
    classifier.load()
    print("Model loaded!")

    if len(sys.argv) > 1:
        image_path = sys.argv[1]
        print(f"\nClassifying: {image_path}")

        # Image.open keeps the file open; the context manager releases the
        # handle as soon as the converted copy has been made.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Fixed example MONET scores so the script runs without the MONET tool.
        monet_scores = [0.2, 0.1, 0.05, 0.3, 0.7, 0.1, 0.05]

        result = classifier.classify(
            clinical_image=image,
            age=55,
            sex="male",
            site="back",
            monet_scores=monet_scores
        )

        print("\nTop Predictions:")
        for pred in result['predictions']:
            print(f"  {pred['probability']:.1%} - {pred['class']} ({pred['full_name']})")