Spaces:
Sleeping
Sleeping
from __future__ import annotations

import asyncio
import io
import logging
import os
import threading
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncIterator, Dict, List, NoReturn, Tuple
# Hide every GPU from TensorFlow and quieten its C++ logging. This must run
# before ``import tensorflow``; values already present in the environment win
# because setdefault never overwrites them.
for _env_name, _env_value in (
    ("CUDA_VISIBLE_DEVICES", "-1"),
    ("TF_CPP_MIN_LOG_LEVEL", "2"),
):
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.applications.resnet import preprocess_input | |
| from fastapi import FastAPI, File, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image, UnidentifiedImageError | |
| try: | |
| tf.config.set_visible_devices([], "GPU") | |
| except Exception: | |
| pass | |
| try: | |
| tf.config.threading.set_intra_op_parallelism_threads(1) | |
| tf.config.threading.set_inter_op_parallelism_threads(1) | |
| except Exception: | |
| pass | |
# Deployment configuration; the model location can be overridden via env vars.
MODEL_REPO = os.environ.get("MODEL_REPO", "REPLACE_ME")  # sentinel = not configured
MODEL_FILENAME = os.environ.get("MODEL_FILENAME", "my_model.h5")
# Same location the loader's hf_hub_download(local_dir="/tmp") writes to, so its
# existence means the model has already been downloaded.
MODEL_CACHE_PATH = Path("/tmp", MODEL_FILENAME)
MAX_FILE_SIZE_BYTES = 5 << 20  # 5 MiB per uploaded image
DEFAULT_IMAGE_SIZE = 224  # used until (or when) the model input has no fixed size
# Class names by model output position: score i is reported as LABELS[i].
LABELS = [
    "Eczema",
    "Viral Infections",
    "Melanoma",
    "Atopic Dermatitis",
    "Basal Cell Carcinoma",
    "Melanocytic Nevi",
    "Keratosis-like Lesions",
    "Psoriasis & Lichen Planus",
    "Seborrheic Keratoses",
    "Fungal Infections",
]
INDEX_TO_LABEL = dict(enumerate(LABELS))
# Headers sent with the JSON prediction responses so neither browsers nor
# intermediaries cache results (Pragma/Expires cover HTTP/1.0 caches).
NO_CACHE_HEADERS = dict(
    [
        ("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0"),
        ("Pragma", "no-cache"),
        ("Expires", "0"),
    ]
)
# Lazily populated inference state; load_model_once() fills these in.
model: tf.keras.Model | None = None
# Input size fed to the model; replaced by the model's own input size once loaded.
image_height = image_width = DEFAULT_IMAGE_SIZE
model_lock = threading.Lock()  # serialises the one-time model load
def _raise_http_error(status_code: int, detail: str) -> NoReturn:
    """Raise an ``HTTPException`` carrying ``status_code`` and ``detail``.

    Annotated ``NoReturn`` (it never returns normally) so type checkers treat code
    after a call as unreachable, e.g. ``model`` is narrowed to non-None after an
    ``if model is None: _raise_http_error(...)`` guard.
    """
    raise HTTPException(status_code=status_code, detail=detail)
def _download_model_file() -> str:
    """Return a local path to the model file, downloading it only when missing.

    The download targets ``local_dir="/tmp"``, the directory MODEL_CACHE_PATH points
    into, so a file left by an earlier run in the same container is reused without
    contacting the Hub.
    """
    if MODEL_CACHE_PATH.exists():
        return str(MODEL_CACHE_PATH)
    return hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILENAME,
        local_dir="/tmp",
        # NOTE(review): recent huggingface_hub releases deprecate this argument;
        # verify against the pinned version before removing it.
        local_dir_use_symlinks=False,
        token=os.getenv("HF_TOKEN"),
    )


def _input_size_from_shape(input_shape: Any) -> Tuple[int, int]:
    """Check ``input_shape`` describes one channels-last RGB image input; return (height, width).

    Unknown (``None``) spatial dimensions fall back to DEFAULT_IMAGE_SIZE.
    Raises HTTPException(500) for multiple inputs, non-rank-4 input or non-3 channels.
    """
    if isinstance(input_shape, list):
        _raise_http_error(500, "The model must expose a single image input tensor.")
    if len(input_shape) != 4:
        _raise_http_error(500, f"Expected a rank-4 image input, got {input_shape}.")
    _, inferred_height, inferred_width, channels = input_shape
    if channels not in (3, None):
        _raise_http_error(500, f"Expected RGB input with 3 channels, got {input_shape}.")
    height = int(inferred_height) if inferred_height is not None else DEFAULT_IMAGE_SIZE
    width = int(inferred_width) if inferred_width is not None else DEFAULT_IMAGE_SIZE
    return height, width


def _check_output_shape(output_shape: Any) -> None:
    """Raise HTTPException(500) unless the model has one output sized like LABELS."""
    if isinstance(output_shape, list):
        _raise_http_error(500, "The model must expose a single output tensor.")
    if output_shape[-1] != len(LABELS):
        _raise_http_error(
            500,
            f"Model output size {output_shape[-1]} does not match the required label set of {len(LABELS)} classes.",
        )


def _load_model_from_hugging_face() -> Tuple[tf.keras.Model, int, int]:
    """Download the model once, load it, and infer the expected image size.

    Returns:
        ``(model, height, width)``, where height/width come from the model's input
        shape, or DEFAULT_IMAGE_SIZE for dimensions the model leaves unspecified.

    Raises:
        HTTPException: 500 if MODEL_REPO is still the placeholder, or if the
            model's input/output shapes do not fit this service.
    """
    if MODEL_REPO == "REPLACE_ME":
        _raise_http_error(
            500,
            "MODEL_REPO is not configured. Set MODEL_REPO to your Hugging Face repo id.",
        )
    # compile=False: inference only needs the architecture and weights; skipping
    # the training config avoids deserialising optimizer/loss/metric objects,
    # which fails for custom losses that are not registered here.
    loaded_model = tf.keras.models.load_model(_download_model_file(), compile=False)
    height, width = _input_size_from_shape(loaded_model.input_shape)
    _check_output_shape(loaded_model.output_shape)
    return loaded_model, height, width
def load_model_once() -> None:
    """Load the model into the module globals on first call; later calls are no-ops.

    Double-checked locking: the unlocked ``model is not None`` test is the fast
    path, and ``model_lock`` makes sure concurrent first callers trigger a single
    download/load. ``model`` is therefore assigned *last*, so any caller that sees
    it set also sees the matching ``image_height``/``image_width``.

    Raises:
        Whatever ``_load_model_from_hugging_face`` raises; nothing is published then.
    """
    global model, image_height, image_width
    if model is not None:
        return
    with model_lock:
        if model is not None:
            return
        loaded_model, height, width = _load_model_from_hugging_face()
        # Publish the input size before the guard variable read outside the lock.
        image_height = height
        image_width = width
        model = loaded_model
def _no_cache_response(content: Dict[str, Any], status_code: int = 200) -> JSONResponse:
    """Wrap ``content`` in a JSONResponse that carries NO_CACHE_HEADERS.

    Args:
        content: JSON-serialisable payload.
        status_code: HTTP status for the response (200 by default).
    """
    response = JSONResponse(
        content=content,
        status_code=status_code,
        headers=NO_CACHE_HEADERS,
    )
    return response
def _validate_image_bytes(file_bytes: bytes, file_name: str | None, content_type: str | None) -> bytes:
    """Reject uploads that are empty, too large, not an image, or not decodable.

    Args:
        file_bytes: Raw upload contents.
        file_name: Client-supplied file name, used only in error messages.
        content_type: Client-supplied MIME type; when absent, only Pillow's check applies.

    Returns:
        ``file_bytes`` unchanged.

    Raises:
        HTTPException: 400 if empty; 413 if larger than MAX_FILE_SIZE_BYTES or if
            Pillow rejects the image as a decompression bomb; 415 if the content
            type is not ``image/*`` or Pillow cannot open/verify the data.
    """
    if not file_bytes:
        _raise_http_error(400, "Uploaded file is empty.")
    if len(file_bytes) > MAX_FILE_SIZE_BYTES:
        _raise_http_error(413, "File size exceeds the 5MB limit.")
    if content_type and not content_type.startswith("image/"):
        _raise_http_error(415, f"Unsupported media type for {file_name or 'uploaded file'}.")
    try:
        with Image.open(io.BytesIO(file_bytes)) as image:
            # verify() checks file integrity without decoding the pixel data.
            image.verify()
    except Image.DecompressionBombError:
        # A small compressed file can declare enormous dimensions; Image.open
        # refuses it. Report a client error instead of an unhandled 500.
        _raise_http_error(413, f"Uploaded file {file_name or 'file'} has too many pixels.")
    except (UnidentifiedImageError, OSError, ValueError, SyntaxError):
        # Pillow's format plugins can report malformed data (e.g. PNG checksum
        # errors during verify) as SyntaxError.
        _raise_http_error(415, f"Uploaded file {file_name or 'file'} is not a valid image.")
    return file_bytes
def _preprocess_image_bytes(file_bytes: bytes) -> np.ndarray:
    """Decode image bytes into a float32 model input (no batch axis).

    The image is converted to RGB and resized to ``(image_width, image_height)``
    with Lanczos resampling. Those sizes are the model's input size once
    ``load_model_once`` has run. The result then goes through Keras' ResNet
    ``preprocess_input``.

    Args:
        file_bytes: Encoded image in any format Pillow can open.

    Returns:
        Array of shape ``(image_height, image_width, 3)`` with dtype float32.

    Raises:
        HTTPException: 413 if Pillow rejects the image as a decompression bomb,
            415 if the bytes cannot be decoded as an image.
    """
    try:
        with Image.open(io.BytesIO(file_bytes)) as source:
            resized = source.convert("RGB").resize(
                (image_width, image_height),
                Image.Resampling.LANCZOS,
            )
        pixels = np.asarray(resized, dtype=np.float32)
    except Image.DecompressionBombError:
        _raise_http_error(413, "Image dimensions exceed the allowed pixel limit.")
    except (UnidentifiedImageError, OSError, ValueError, SyntaxError) as exc:
        # Pillow's format plugins can report malformed data as SyntaxError.
        _raise_http_error(415, f"Invalid image data: {exc}")
    # Keras' ResNet preprocess_input is caffe-style: RGB -> BGR plus per-channel
    # ImageNet mean subtraction, with no 1/255 scaling. A previous version of this
    # function scaled by 1/255 instead. Whichever is used presumably has to match
    # the model's training pipeline - confirm before changing it.
    return preprocess_input(pixels)
def _predict_batch_numpy(batch: np.ndarray) -> np.ndarray:
    """Run the loaded Keras model over ``batch`` and return float32 predictions.

    Blocking (it calls ``model.predict``); the endpoints invoke it through
    ``asyncio.to_thread``.

    Raises:
        HTTPException: 500 when the module-level ``model`` has not been loaded.
    """
    loaded = model
    if loaded is None:
        _raise_http_error(500, "Model is not loaded.")
    raw_predictions = loaded.predict(batch, verbose=0)
    return np.asarray(raw_predictions, dtype=np.float32)
def _format_probabilities(probabilities: np.ndarray) -> Dict[str, float]:
    """Map every label in LABELS to its score, converted to a plain Python float.

    ``probabilities`` is indexed by label position; it must have at least
    ``len(LABELS)`` entries, and any extra entries are ignored.
    """
    labelled: Dict[str, float] = {}
    for position, label in enumerate(LABELS):
        labelled[label] = float(probabilities[position])
    return labelled
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
    """Application lifespan: load the model before the app starts serving requests.

    FastAPI's ``lifespan=`` parameter expects an async context manager factory,
    hence ``@asynccontextmanager``. A bare async generator function relies on a
    deprecated Starlette fallback, or fails where that fallback is absent.
    ``load_model_once`` is synchronous, so the download/load blocks the event loop
    during startup, before any request is accepted.
    """
    load_model_once()
    yield
app = FastAPI(title="Skin Disease Classification API", lifespan=lifespan)
# CORS: allow_origins entries are compared literally, so a pattern such as
# "http://localhost.*" never matches a real Origin header. Local development
# origins (plain http, any port) are matched with an anchored regex instead.
# The anchors also stop hosts like "localhost.attacker.example" from matching.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://face-skin-disease-frontend.devfuze.workers.dev",
    ],
    allow_origin_regex=r"^http://localhost(:\d+)?$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def health() -> Dict[str, str]:
    """Liveness probe; always reports ``{"status": "ok"}``.

    NOTE(review): no route decorator is visible on this function in this copy of
    the file, so as shown it is never registered on ``app``. Restore the
    intended ``@app.get(...)`` line.
    """
    status: Dict[str, str] = dict(status="ok")
    return status
def _decode_image_upload(
    file_bytes: bytes, file_name: str | None, content_type: str | None
) -> np.ndarray:
    """Validate raw upload bytes and return the preprocessed image (blocking work)."""
    _validate_image_bytes(file_bytes, file_name, content_type)
    return _preprocess_image_bytes(file_bytes)


async def _read_image_upload(file: UploadFile) -> np.ndarray:
    """Read one uploaded file and turn it into a model-ready array.

    Reads at most MAX_FILE_SIZE_BYTES + 1 bytes. One byte past the limit is enough
    for ``_validate_image_bytes`` to answer 413, without copying an arbitrarily
    large upload into memory. Pillow decoding and resizing are CPU-bound, so they
    run in a worker thread instead of blocking the event loop.
    """
    file_bytes = await file.read(MAX_FILE_SIZE_BYTES + 1)
    return await asyncio.to_thread(
        _decode_image_upload, file_bytes, file.filename, file.content_type
    )


# NOTE(review): no route decorator is visible on this handler in this copy of the
# file, so as shown it is never registered on ``app``. Restore the intended
# ``@app.post(...)`` line; the path cannot be recovered from here.
async def predict_single(file: UploadFile = File(...)) -> JSONResponse:
    """Classify one uploaded image.

    Returns:
        No-cache JSON with ``predicted_class`` (arg-max label), ``confidence`` (the
        model's score for that label) and ``probabilities`` (score per label).

    Raises:
        HTTPException: from upload validation (400/413/415), or 500 when the model
            is not loaded.
    """
    image_array = await _read_image_upload(file)
    batch = np.expand_dims(image_array, axis=0)
    predictions = await asyncio.to_thread(_predict_batch_numpy, batch)
    scores = predictions[0]
    predicted_index = int(np.argmax(scores))
    return _no_cache_response(
        {
            "predicted_class": INDEX_TO_LABEL[predicted_index],
            "confidence": float(scores[predicted_index]),
            "probabilities": _format_probabilities(scores),
        }
    )


# NOTE(review): as with predict_single, no route decorator is visible here.
# NOTE(review): the number of files is not capped, so one request can hold many
# decoded images in memory at once; consider a limit.
async def predict_batch(files: List[UploadFile] = File(...)) -> JSONResponse:
    """Classify several uploaded images together and combine their scores.

    Each image is scored separately, the per-label scores are averaged across
    images, and the arg-max of that average is reported.

    Returns:
        No-cache JSON with ``num_images``, ``final_prediction``, ``confidence``
        (averaged score of that label) and ``probabilities`` (averaged per label).

    Raises:
        HTTPException: 400 if no files were sent, upload validation errors
            (400/413/415), or 500 when the model is not loaded.
    """
    if not files:
        _raise_http_error(400, "No images uploaded.")
    images: List[np.ndarray] = [await _read_image_upload(file) for file in files]
    batch = np.stack(images, axis=0)
    predictions = await asyncio.to_thread(_predict_batch_numpy, batch)
    average_probabilities = np.mean(predictions, axis=0)
    predicted_index = int(np.argmax(average_probabilities))
    return _no_cache_response(
        {
            "num_images": len(files),
            "final_prediction": INDEX_TO_LABEL[predicted_index],
            "confidence": float(average_probabilities[predicted_index]),
            "probabilities": _format_probabilities(average_probabilities),
        }
    )