| """Medical Image Classification model wrapper class that loads the model, preprocesses inputs and performs inference.""" |
|
|
| import torch |
| from PIL import Image |
| import pandas as pd |
| from typing import List, Tuple |
| import os |
| import tempfile |
| import base64 |
| import io |
|
|
| from MedImageInsight.UniCLModel import build_unicl_model |
| from MedImageInsight.Utils.Arguments import load_opt_from_config_files |
| from MedImageInsight.ImageDataLoader import build_transforms |
| from MedImageInsight.LangEncoder import build_tokenizer |
|
|
|
|
class MedImageInsight:
    """Wrapper class for a medical image classification model.

    Loads a UniCL vision/language model from disk, decodes and preprocesses
    base64-encoded images, and exposes zero-shot classification
    (:meth:`predict`) and embedding extraction (:meth:`encode`).
    """

    def __init__(
        self,
        model_dir: str,
        vision_model_name: str,
        language_model_name: str
    ) -> None:
        """Initialize the medical image classifier.

        Args:
            model_dir: Directory containing model files and config.
            vision_model_name: File name of the vision model checkpoint
                inside ``<model_dir>/vision_model``.
            language_model_name: Name of the language model (kept for
                bookkeeping; the tokenizer path is fixed in load_model).
        """
        self.model_dir = model_dir
        self.vision_model_name = vision_model_name
        self.language_model_name = language_model_name
        # The following are populated by load_model(); None until then.
        self.model = None
        self.device = None
        self.tokenize = None
        self.preprocess = None
        self.opt = None
        self.max_length = None

    def load_model(self) -> None:
        """Load the model, preprocessing transforms and tokenizer.

        Reads ``config.yaml`` from ``model_dir``, points the config at the
        on-disk tokenizer and vision weights, builds the UniCL model, and
        moves it to the best available device (MPS > CUDA > CPU).

        Raises:
            Exception: Propagates any failure during loading after printing
                a short diagnostic message.
        """
        try:
            config_path = os.path.join(self.model_dir, 'config.yaml')
            self.opt = load_opt_from_config_files([config_path])

            # Override config paths so they resolve relative to model_dir.
            self.opt['LANG_ENCODER']['PRETRAINED_TOKENIZER'] = os.path.join(
                self.model_dir,
                'language_model',
                'clip_tokenizer_4.16.2'
            )
            self.opt['UNICL_MODEL']['PRETRAINED'] = os.path.join(
                self.model_dir,
                'vision_model',
                self.vision_model_name
            )

            # False -> evaluation-mode transforms (no training augmentation).
            self.preprocess = build_transforms(self.opt, False)
            self.model = build_unicl_model(self.opt)

            # Device selection: prefer Apple MPS, then CUDA, then CPU.
            if torch.backends.mps.is_available():
                self.device = torch.device("mps")
                print("Using MPS backend for model execution.")
            else:
                self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                if self.device.type == "cuda":
                    print("Using CUDA backend for model execution.")
                else:
                    print("Using CPU for model execution.")

            self.model.to(self.device)

            self.tokenize = build_tokenizer(self.opt['LANG_ENCODER'])
            self.max_length = self.opt['LANG_ENCODER']['CONTEXT_LENGTH']

            print(f"Model loaded successfully on device: {self.device}")

        except Exception:
            print("Failed to load the model:")
            raise  # plain raise preserves the original traceback

    @staticmethod
    def decode_base64_image(base64_str: str) -> Image.Image:
        """Decode a base64 string to a PIL Image and ensure RGB format.

        Args:
            base64_str: Base64 encoded image string, optionally prefixed
                with a data-URI header (e.g. ``data:image/png;base64,...``).

        Returns:
            PIL Image object in RGB format.

        Raises:
            ValueError: If the string cannot be decoded into an image.
        """
        try:
            # Strip a data-URI prefix if present; split only on the first
            # comma so payloads that contain ',' are not truncated.
            if ',' in base64_str:
                base64_str = base64_str.split(',', 1)[1]

            image_bytes = base64.b64decode(base64_str)
            image = Image.open(io.BytesIO(image_bytes))

            # Normalize every non-RGB mode (grayscale, palette, RGBA, CMYK,
            # ...) to RGB so downstream transforms always see 3 channels.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            return image
        except Exception as e:
            raise ValueError(f"Failed to decode base64 image: {str(e)}")

    def predict(self, images: List[str], labels: List[str], multilabel: bool = False) -> List[dict]:
        """Perform zero-shot classification on the input images.

        Args:
            images: List of base64 encoded image strings.
            labels: List of candidate labels for classification.
            multilabel: If True, probabilities are independent sigmoids;
                if False, a softmax over ``labels``.

        Returns:
            One dict per input image mapping each label to its probability,
            with keys inserted in descending-probability order.

        Raises:
            RuntimeError: If the model has not been loaded yet.
            ValueError: If no labels are given or an image cannot be decoded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if not labels:
            raise ValueError("No labels provided")

        # Decode every image up front so a bad input fails before inference.
        image_list = []
        for img_base64 in images:
            try:
                image_list.append(self.decode_base64_image(img_base64))
            except Exception as e:
                raise ValueError(f"Failed to process image: {str(e)}")

        probs = self.run_inference_batch(image_list, labels, multilabel)
        probs_np = probs.cpu().numpy()

        results = []
        for prob_row in probs_np:
            # Sort by descending probability; dicts preserve insertion
            # order, so callers see the most likely label first.
            label_probs = sorted(
                ((label, float(prob)) for label, prob in zip(labels, prob_row)),
                key=lambda pair: pair[1],
                reverse=True,
            )
            results.append(dict(label_probs))

        return results

    def encode(self, images: List[str] = None, texts: List[str] = None):
        """Compute embeddings for images and/or texts.

        Args:
            images: Optional list of base64 encoded image strings.
            texts: Optional list of strings to embed.

        Returns:
            Dict with keys ``"image_embeddings"`` and ``"text_embeddings"``;
            each value is a numpy array, or None when the corresponding
            input was not provided.

        Raises:
            RuntimeError: If the model has not been loaded yet.
            ValueError: If neither images nor texts are provided, or an
                image cannot be decoded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if not images and not texts:
            raise ValueError("You must provide either images or texts")

        output = {
            "image_embeddings": None,
            "text_embeddings": None,
        }

        if images is not None:
            image_list = []
            for img_base64 in images:
                try:
                    image_list.append(self.decode_base64_image(img_base64))
                except Exception as e:
                    raise ValueError(f"Failed to process image: {str(e)}")
            image_batch = torch.stack([self.preprocess(img) for img in image_list]).to(self.device)
            with torch.no_grad():
                output["image_embeddings"] = self.model.encode_image(image_batch).cpu().numpy()

        if texts is not None:
            text_tokens = self.tokenize(
                texts,
                padding='max_length',
                max_length=self.max_length,
                truncation=True,
                return_tensors='pt'
            )
            text_tokens = {k: v.to(self.device) for k, v in text_tokens.items()}
            # no_grad for consistency with image encoding: this is pure
            # inference, so no autograd graph should be built.
            with torch.no_grad():
                output["text_embeddings"] = self.model.encode_text(text_tokens).cpu().numpy()

        return output

    def run_inference_batch(
        self,
        images: List[Image.Image],
        texts: List[str],
        multilabel: bool = False
    ) -> torch.Tensor:
        """Perform inference on a batch of input images.

        Args:
            images: List of PIL Image objects.
            texts: List of text labels.
            multilabel: If True, use sigmoid for multilabel classification.
                If False, use softmax for single-label classification.

        Returns:
            Tensor of shape (num_images, num_texts) with prediction
            probabilities.
        """
        image_batch = torch.stack([self.preprocess(img) for img in images]).to(self.device)

        text_tokens = self.tokenize(
            texts,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )
        text_tokens = {k: v.to(self.device) for k, v in text_tokens.items()}

        with torch.no_grad():
            # outputs: (image_features, text_features, logit_scale)
            outputs = self.model(image=image_batch, text=text_tokens)
            logits_per_image = outputs[0] @ outputs[1].t() * outputs[2]

            if multilabel:
                # Independent per-label probabilities.
                probs = torch.sigmoid(logits_per_image)
            else:
                # Mutually exclusive labels: distribution over texts.
                probs = logits_per_image.softmax(dim=1)

        return probs
|
|
|
|
|
|