| """Model and dataset loading, inference, and label extraction functions.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| from functools import lru_cache |
| from typing import Any, Dict, Optional |
|
|
| import numpy as np |
| import torch |
| from datasets import DatasetDict, load_dataset |
| from PIL import Image |
| from torchvision import transforms |
| from torchvision.transforms import functional as TF |
| from transformers import ( |
| AutoImageProcessor, |
| AutoModelForImageClassification, |
| ) |
|
|
# Hugging Face Hub repo providing the Curia image processor and the per-head
# model subfolders (see load_processor / load_model).
HF_REPO_ID: str = "raidium/curia"
# Hugging Face Hub dataset repo whose subsets are loaded by load_curia_dataset.
HF_DATASET_ID: str = "raidium/CuriaBench"
|
|
|
|
class _NumpyToTensor:
    """Convert array-likes to tensors while passing tensors/images through.

    Inputs that are neither a ``torch.Tensor`` nor a ``PIL.Image.Image`` are
    copied into a new tensor with ``torch.tensor`` and given an extra leading
    dimension of size 1. Tensors and PIL images are returned unchanged, and
    no leading dimension is added to them.
    """

    def __call__(self, value: Any) -> torch.Tensor | Image.Image:
        # Already a tensor or an image: leave it to the next transform as-is.
        if isinstance(value, (torch.Tensor, Image.Image)):
            return value
        # Array-like: copy into a tensor and prepend a size-1 dimension.
        return torch.tensor(value).unsqueeze(0)
|
|
|
|
class AdaptativeResizeMask(torch.nn.Module):
    """Resize a mask to a square and binarise it, avoiding an empty result.

    The mask is cast to float32 and bilinearly resized (with antialiasing) to
    ``target_size x target_size``. It is then thresholded at
    ``initial_threshold``. If no element at all passes that threshold, it is
    lowered to half of the resized mask's maximum value. The check is global
    over the whole tensor, not per channel/slice.
    """

    def __init__(self, target_size: int = 512, initial_threshold: float = 0.5) -> None:
        super().__init__()
        self.target_size = target_size
        self.initial_threshold = initial_threshold

    def forward(self, mask: torch.Tensor) -> torch.Tensor:
        """Return a float32 0/1 mask whose last two dims are (target_size, target_size)."""
        side = self.target_size
        smoothed = TF.resize(
            mask.to(dtype=torch.float32),
            (side, side),
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True,
        )
        keep = smoothed > self.initial_threshold
        if not keep.any():
            # The fixed threshold wiped the mask out (e.g. a tiny ROI blurred by
            # downsampling): retry relative to the resized mask's peak.
            keep = smoothed > torch.max(smoothed) * 0.5
        return keep.to(dtype=torch.float32)
|
|
|
|
@lru_cache(maxsize=1)
def make_mask_transform(crop_size: int = 512) -> transforms.Compose:
    """Build the mask pipeline used during training/inference.

    The pipeline first converts array input to a tensor with a leading
    dimension (``_NumpyToTensor``). It then resizes and binarises the result
    to ``crop_size`` (``AdaptativeResizeMask``). Only the most recent
    ``crop_size`` is memoised.
    """
    steps = [_NumpyToTensor(), AdaptativeResizeMask(target_size=crop_size)]
    return transforms.Compose(steps)
|
|
|
|
def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]:
    """Apply Curia's mask preprocessing so heads get the ROI they expect.

    The mask is converted with ``np.array`` and passed through
    ``make_mask_transform()``, which resizes and binarises it to the
    transform's square target size ``S``.

    Args:
        mask: Anything ``np.array`` accepts (NumPy array, nested list, PIL
            image, ...), or ``None``.

    Returns:
        ``None`` when there is no usable mask. That covers ``None``, values
        ``np.array`` cannot convert, empty arrays, and 0-D/1-D arrays.
        Otherwise a float32 0/1 tensor:

        * a 2-D ``(H, W)`` mask becomes ``(1, 1, S, S)``;
        * a 3-D ``(H, W, C)`` mask becomes ``(1, S, S, C)`` (channel-last).

        NOTE(review): arrays with more than 3 dims take the 2-D path with an
        extra leading dim; confirm whether such inputs are meant to be supported.
    """
    if mask is None:
        return None

    try:
        mask_arr = np.array(mask)
    except Exception:
        # Deliberate best effort: an unconvertible mask is treated like no mask.
        return None

    # A spatial mask needs at least (H, W). Scalars and 1-D data would only
    # make TF.resize fail, so treat them like empty masks.
    if mask_arr.size == 0 or mask_arr.ndim < 2:
        return None

    mask_transform = make_mask_transform()

    if mask_arr.ndim == 3:
        # (H, W, C) -> (C, H, W); _NumpyToTensor then prepends a size-1 dim.
        tensor = mask_transform(mask_arr.transpose(2, 0, 1))
        # (1, C, S, S) -> (1, S, S, C): restore the channel-last layout.
        tensor = tensor.transpose(1, 3).transpose(1, 2)
    else:
        # Prepend a channel dim directly on the array's tensor view.
        # torch.tensor([ndarray]) goes through a Python list, which is slow
        # and triggers a UserWarning.
        tensor = mask_transform(torch.from_numpy(mask_arr).unsqueeze(0))
        # (1, S, S) -> (1, 1, S, S): add the batch dim.
        tensor = tensor.unsqueeze(0)

    return tensor
|
|
|
|
@lru_cache(maxsize=1)
def load_id_to_labels() -> Dict[str, Dict[int, str]]:
    """Load the ``id_to_labels.json`` mapping stored next to this module.

    The file maps each head name to an object of ``{"<index>": label}``
    pairs. JSON object keys are always strings, so the inner keys are
    converted to ``int`` to allow lookup by model output index.

    The result is memoised and the same dict is returned on every call, so
    callers must not mutate it.

    Raises:
        FileNotFoundError: if ``id_to_labels.json`` is missing.
        ValueError: if an inner key is not an integer string, or the file is
            not valid JSON (``json.JSONDecodeError`` subclasses ValueError).
    """
    json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json")
    # Explicit encoding: labels may contain non-ASCII text, and the platform
    # default encoding is locale dependent.
    with open(json_path, encoding="utf-8") as f:
        raw = json.load(f)

    # Build a fresh mapping instead of reassigning entries while iterating.
    return {
        head: {int(idx): label for idx, label in mapping.items()}
        for head, mapping in raw.items()
    }
|
|
|
|
@lru_cache(maxsize=1)
def load_processor() -> AutoImageProcessor:
    """Fetch Curia's image processor from the Hub once and memoise it.

    The ``HF_TOKEN`` environment variable is forwarded as the access token
    (``None`` when unset). ``trust_remote_code=True`` allows processor code
    shipped in the repository to run.
    """
    return AutoImageProcessor.from_pretrained(
        HF_REPO_ID,
        token=os.environ.get("HF_TOKEN"),
        trust_remote_code=True,
    )
|
|
|
|
@lru_cache(maxsize=None)
def load_model(head: str) -> AutoModelForImageClassification:
    """Load the classification model for ``head`` and put it in eval mode.

    ``head`` selects the subfolder of the Curia repo to load from, and
    ``HF_TOKEN`` from the environment is forwarded as the access token. Models
    are memoised per head with no size bound, so every loaded head stays
    resident in memory.
    """
    hub_kwargs = {
        "trust_remote_code": True,
        "subfolder": head,
        "token": os.environ.get("HF_TOKEN"),
    }
    net = AutoModelForImageClassification.from_pretrained(HF_REPO_ID, **hub_kwargs)
    net.eval()
    return net
|
|
|
|
@lru_cache(maxsize=None)
def load_curia_dataset(subset: str) -> Any:
    """Load the ``test`` split of one CuriaBench subset, memoised per subset.

    ``HF_TOKEN`` from the environment is forwarded as the access token. If the
    loader hands back a ``DatasetDict``, its ``"test"`` entry is returned.
    Otherwise the loaded object is returned as-is.
    """
    loaded = load_dataset(
        HF_DATASET_ID,
        subset,
        split="test",
        token=os.environ.get("HF_TOKEN"),
    )
    return loaded["test"] if isinstance(loaded, DatasetDict) else loaded
|
|
def infer_image(
    image: np.ndarray,
    head: str,
    mask: Any | None = None,
    return_probs: bool = True,
) -> torch.Tensor:
    """Run one image, plus an optional ROI mask, through the model for ``head``.

    The image is preprocessed by the shared processor. A usable ``mask`` is
    preprocessed by ``prepare_mask_for_model`` and passed to the model as the
    ``mask`` input. The forward pass runs under ``torch.no_grad()``.

    Returns:
        With ``return_probs`` (default): the softmax over the last dimension of
        the first batch item's logits. Otherwise: the first batch item's raw
        logits with all size-1 dimensions squeezed out.
    """
    processor = load_processor()
    model = load_model(head)
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt")
        roi = prepare_mask_for_model(mask)
        if roi is not None:
            inputs["mask"] = roi
        first_logits = model(**inputs)["logits"][0]
    if not return_probs:
        return first_logits.squeeze()
    return torch.nn.functional.softmax(first_logits, dim=-1)
|
|