| """ResNet inference service implementation.""" |
|
|
| import base64 |
| import os |
| from io import BytesIO |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from transformers import AutoImageProcessor, ResNetForImageClassification |
|
|
| from app.core.logging import logger |
| from app.services.base import InferenceService |
| from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse |
|
|
|
|
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification.

    Loads a locally stored HuggingFace ResNet checkpoint (no network
    download at runtime) and serves single-image classification requests,
    returning log-probabilities over the ``Labels`` classes plus a fixed
    center-region localization mask.
    """

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Configure the service; weights are loaded lazily by ``load_model``.

        Args:
            model_name: HuggingFace-style model identifier; also used as the
                subdirectory under ``models/`` where the weights must exist.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load processor and model weights from local disk (idempotent).

        Raises:
            FileNotFoundError: If the model directory or its config.json
                is missing — callers are expected to have pre-downloaded
                the checkpoint into ``models/<model_name>``.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        # HF emits FutureWarnings for some processor kwargs; suppress them
        # locally so request logs stay clean.
        import warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
        assert self.model is not None  # narrow Optional for type checkers

        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image from the request.

        Args:
            request: Request whose ``image.data`` field holds a
                base64-encoded image payload.

        Returns:
            PredictionResponse with log-probabilities over the ``Labels``
            classes and a coarse center-third localization mask.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        # Fix: close the PIL image (and its underlying buffer) once the
        # pixels have been consumed instead of leaking it per request.
        with Image.open(BytesIO(image_data)) as image:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            inputs = self.processor(image, return_tensors="pt")
            width, height = image.width, image.height

        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze()

        # Restrict to the first len(Labels) logits and renormalize over that
        # subset (assumes Labels maps onto the model's leading class ids —
        # TODO confirm against Labels' definition). Fix: pass dim explicitly;
        # implicit dim in log_softmax is deprecated and warns on every call.
        # On the squeezed 1-D logits, dim=-1 matches the previous behavior.
        logprobs = torch.nn.functional.log_softmax(
            logits[:len(Labels)], dim=-1
        ).tolist()

        # Fixed localization heuristic: mark the center third of the image.
        x = width // 3
        y = height // 3
        mask = np.zeros((height, width), dtype=np.uint8)
        mask[y:(2 * y), x:(2 * x)] = 1
        mask_obj = BinaryMask.from_numpy(mask)

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=mask_obj,
        )

    @property
    def is_loaded(self) -> bool:
        # True once load_model() has completed successfully.
        return self._is_loaded
|
|