meharobaidullah's picture
Update app.py
dad5e01 verified
from __future__ import annotations

import asyncio
import io
import os
import threading
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, NoReturn, Tuple
# Disable GPU usage before TensorFlow is imported.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet import preprocess_input
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError
# Hide every GPU from TensorFlow so inference stays on the CPU. This is a
# best-effort call: any failure is ignored so the module still imports.
try:
    tf.config.set_visible_devices([], "GPU")
except Exception:
    pass

# Limit TensorFlow to one intra-op and one inter-op thread. Also best effort;
# if the first call fails the second one is skipped.
try:
    tf.config.threading.set_intra_op_parallelism_threads(1)
    tf.config.threading.set_inter_op_parallelism_threads(1)
except Exception:
    pass
# Hugging Face repo id and file name of the model; both overridable through
# environment variables. "REPLACE_ME" marks an unconfigured repo.
MODEL_REPO = os.getenv("MODEL_REPO", "REPLACE_ME")
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "my_model.h5")

# Local location that is checked before the model is downloaded again.
MODEL_CACHE_PATH = Path("/tmp", MODEL_FILENAME)

# Upload size ceiling (5 MiB) and the fallback square input size in pixels.
MAX_FILE_SIZE_BYTES = 5 * 1024 ** 2
DEFAULT_IMAGE_SIZE = 224

# Class names; the position of each entry is the class index used to look up
# model outputs.
LABELS = [
    "Eczema",
    "Viral Infections",
    "Melanoma",
    "Atopic Dermatitis",
    "Basal Cell Carcinoma",
    "Melanocytic Nevi",
    "Keratosis-like Lesions",
    "Psoriasis & Lichen Planus",
    "Seborrheic Keratoses",
    "Fungal Infections",
]
INDEX_TO_LABEL = dict(enumerate(LABELS))

# Headers attached to prediction responses so they are never cached.
NO_CACHE_HEADERS = {
    "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
    "Pragma": "no-cache",
    "Expires": "0",
}
# Process-wide inference state. `model` stays None until load_model_once()
# fills it in; the image dimensions start at the default size and are
# overwritten with the values derived from the loaded model.
model: tf.keras.Model | None = None
image_height = image_width = DEFAULT_IMAGE_SIZE

# Serialises the one-time model load.
model_lock = threading.Lock()
def _raise_http_error(status_code: int, detail: str) -> NoReturn:
    """Raise an ``HTTPException`` built from ``status_code`` and ``detail``.

    The return annotation is ``NoReturn`` because this function never returns
    normally; that lets type checkers treat the code after a call as
    unreachable (for example, narrowing ``model`` after a ``None`` check).

    Args:
        status_code: HTTP status code to report.
        detail: Human-readable error message placed in the response.

    Raises:
        HTTPException: Always.
    """
    raise HTTPException(status_code=status_code, detail=detail)
def _load_model_from_hugging_face() -> Tuple[tf.keras.Model, int, int]:
    """Fetch the Keras model (unless already cached) and derive its input size.

    The file at ``MODEL_CACHE_PATH`` is reused when it exists; otherwise the
    model is downloaded from ``MODEL_REPO`` into ``/tmp``. The loaded model is
    then checked for a single rank-4 input with 3 (or unspecified) channels
    and a single output whose last dimension equals ``len(LABELS)``.

    Returns:
        A ``(model, height, width)`` tuple. Height and width fall back to
        ``DEFAULT_IMAGE_SIZE`` when the model leaves them unspecified.

    Raises:
        HTTPException: With status 500 when ``MODEL_REPO`` is still the
            placeholder or the model's input/output shape is unsupported.
    """
    if MODEL_REPO == "REPLACE_ME":
        _raise_http_error(
            500,
            "MODEL_REPO is not configured. Set MODEL_REPO to your Hugging Face repo id.",
        )

    # Prefer the copy already on disk; only hit the hub when it is missing.
    if MODEL_CACHE_PATH.exists():
        model_path = str(MODEL_CACHE_PATH)
    else:
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            local_dir="/tmp",
            local_dir_use_symlinks=False,
            token=os.getenv("HF_TOKEN"),
        )

    keras_model = tf.keras.models.load_model(model_path)

    # --- input contract: one (batch, height, width, channels) tensor ---
    input_shape = keras_model.input_shape
    if isinstance(input_shape, list):
        _raise_http_error(500, "The model must expose a single image input tensor.")
    if len(input_shape) != 4:
        _raise_http_error(500, f"Expected a rank-4 image input, got {input_shape}.")

    _, raw_height, raw_width, channels = input_shape
    if not (channels is None or channels == 3):
        _raise_http_error(500, f"Expected RGB input with 3 channels, got {input_shape}.")

    # Unspecified spatial dimensions fall back to the default size.
    height = DEFAULT_IMAGE_SIZE if raw_height is None else int(raw_height)
    width = DEFAULT_IMAGE_SIZE if raw_width is None else int(raw_width)

    # --- output contract: one tensor with one score per label ---
    output_shape = keras_model.output_shape
    if isinstance(output_shape, list):
        _raise_http_error(500, "The model must expose a single output tensor.")
    if output_shape[-1] != len(LABELS):
        _raise_http_error(
            500,
            f"Model output size {output_shape[-1]} does not match the required label set of {len(LABELS)} classes.",
        )

    return keras_model, height, width
def load_model_once() -> None:
    """Populate the module-level model and image size on the first call.

    Later calls return immediately once ``model`` is set. The load itself is
    performed while holding ``model_lock`` and the ``None`` check is repeated
    inside the lock so the model is loaded at most once.

    Raises:
        HTTPException: Propagated from ``_load_model_from_hugging_face``.
    """
    global model, image_height, image_width

    # Fast path without taking the lock.
    if model is not None:
        return

    with model_lock:
        if model is not None:
            return
        loaded_model, height, width = _load_model_from_hugging_face()
        # Store the dimensions first and the model last: the unlocked check
        # above treats a non-None `model` as "everything is ready", so the
        # model must not become visible before the sizes that go with it.
        image_height = height
        image_width = width
        model = loaded_model
def _no_cache_response(content: Dict[str, Any], status_code: int = 200) -> JSONResponse:
    """Build a JSON response for ``content`` that carries ``NO_CACHE_HEADERS``.

    Args:
        content: JSON-serialisable payload.
        status_code: HTTP status of the response (200 by default).
    """
    response = JSONResponse(
        content=content,
        status_code=status_code,
        headers=NO_CACHE_HEADERS,
    )
    return response
def _validate_image_bytes(file_bytes: bytes, file_name: str | None, content_type: str | None) -> bytes:
    """Reject uploads that are empty, too large, mistyped, or not an image.

    Args:
        file_bytes: Raw bytes of the upload.
        file_name: Client-supplied file name, used only in error messages.
        content_type: Client-supplied MIME type; checked only when present.

    Returns:
        ``file_bytes`` unchanged when every check passes.

    Raises:
        HTTPException: 400 for an empty upload, 413 when the size limit or
            Pillow's decompression-bomb pixel limit is exceeded, 415 for a
            non-image content type or undecodable image data.
    """
    if not file_bytes:
        _raise_http_error(400, "Uploaded file is empty.")
    if len(file_bytes) > MAX_FILE_SIZE_BYTES:
        _raise_http_error(413, "File size exceeds the 5MB limit.")
    if content_type and not content_type.startswith("image/"):
        _raise_http_error(415, f"Unsupported media type for {file_name or 'uploaded file'}.")

    try:
        with Image.open(io.BytesIO(file_bytes)) as image:
            image.verify()
    except Image.DecompressionBombError:
        # Raised by Pillow when the declared pixel count is far above its
        # safety limit; report it as a client error instead of a 500.
        _raise_http_error(413, "Uploaded image dimensions are too large.")
    except (UnidentifiedImageError, OSError, ValueError, SyntaxError):
        # SyntaxError is included because Pillow's PNG verification signals
        # checksum failures with it.
        _raise_http_error(415, f"Uploaded file {file_name or 'file'} is not a valid image.")
    return file_bytes
# def _preprocess_image_bytes(file_bytes: bytes) -> np.ndarray:
# try:
# with Image.open(io.BytesIO(file_bytes)) as image:
# image = image.convert("RGB")
# image = image.resize((image_width, image_height))
# image_array = np.asarray(image, dtype=np.float32) / 255.0
# except (UnidentifiedImageError, OSError, ValueError) as exc:
# _raise_http_error(415, f"Invalid image data: {exc}")
# return image_array
def _preprocess_image_bytes(file_bytes: bytes) -> np.ndarray:
    """Decode an upload into a model-ready float32 array.

    The image is converted to RGB, resized to the current
    ``(image_width, image_height)`` with Lanczos resampling, cast to float32
    and passed through ``preprocess_input``.

    Raises:
        HTTPException: 415 when the bytes cannot be decoded or processed.
    """
    target_size = (image_width, image_height)
    try:
        with Image.open(io.BytesIO(file_bytes)) as source:
            rgb = source.convert("RGB")
            resized = rgb.resize(target_size, Image.Resampling.LANCZOS)
            pixels = np.asarray(resized, dtype=np.float32)
            # IMPORTANT: use the ResNet `preprocess_input` imported at the top
            # of the file rather than a plain /255 rescale. Presumably this
            # must match how the model was trained -- TODO confirm.
            image_array = preprocess_input(pixels)
    except (UnidentifiedImageError, OSError, ValueError) as exc:
        _raise_http_error(415, f"Invalid image data: {exc}")
    return image_array
def _predict_batch_numpy(batch: np.ndarray) -> np.ndarray:
    """Run the loaded model on ``batch`` and return float32 predictions.

    Raises:
        HTTPException: 500 when the model has not been loaded yet.
    """
    if model is None:
        _raise_http_error(500, "Model is not loaded.")
    raw_output = model.predict(batch, verbose=0)
    return np.asarray(raw_output, dtype=np.float32)
def _format_probabilities(probabilities: np.ndarray) -> Dict[str, float]:
    """Map every label name to its score as a plain Python ``float``.

    ``probabilities`` is indexed by class position, so it must hold at least
    ``len(LABELS)`` entries.
    """
    formatted: Dict[str, float] = {}
    for position, label in enumerate(LABELS):
        formatted[label] = float(probabilities[position])
    return formatted
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook: load the model before serving requests.

    The model load runs before the ``yield``; nothing is done on shutdown.
    """
    # Startup phase.
    load_model_once()
    # Hand control back to the application; no teardown follows.
    yield
# ASGI application; the model is loaded through the `lifespan` hook.
app = FastAPI(title="Skin Disease Classification API", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    # `allow_origins` entries are compared literally, so only exact origins
    # belong here.
    allow_origins=[
        "https://face-skin-disease-frontend.devfuze.workers.dev",
    ],
    # Pattern-based origins must go through `allow_origin_regex`. The previous
    # "http://localhost.*" entry in `allow_origins` never matched a real
    # origin. The pattern is limited to localhost with an optional port so
    # hosts such as "localhost.example.com" are not accepted.
    allow_origin_regex=r"http://localhost(:\d+)?",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness endpoint; always answers with a fixed ``ok`` status."""
    return dict(status="ok")
@app.post("/predict/single")
async def predict_single(file: UploadFile = File(...)) -> JSONResponse:
    """Classify one uploaded image.

    Returns a no-cache JSON response with ``predicted_class``, ``confidence``
    (the score of the predicted class) and ``probabilities`` (score per label).

    Raises:
        HTTPException: 400/413/415 from upload validation and preprocessing,
            500 when the model is not loaded.
    """
    # Read at most one byte past the limit: enough for the validator to detect
    # an oversized upload without pulling an arbitrarily large file into
    # memory first.
    file_bytes = await file.read(MAX_FILE_SIZE_BYTES + 1)
    _validate_image_bytes(file_bytes, file.filename, file.content_type)

    image_array = _preprocess_image_bytes(file_bytes)
    batch = np.expand_dims(image_array, axis=0)

    # Inference is blocking, so run it off the event loop.
    predictions = await asyncio.to_thread(_predict_batch_numpy, batch)
    scores = predictions[0]
    predicted_index = int(np.argmax(scores))

    return _no_cache_response(
        {
            "predicted_class": INDEX_TO_LABEL[predicted_index],
            "confidence": float(scores[predicted_index]),
            "probabilities": _format_probabilities(scores),
        }
    )
@app.post("/predict/batch")
async def predict_batch(files: List[UploadFile] = File(...)) -> JSONResponse:
    """Classify several images together and report one averaged prediction.

    Every upload is validated and preprocessed, the images are stacked into a
    single batch, and the per-image predictions are averaged before the top
    class is chosen.

    Returns a no-cache JSON response with ``num_images``, ``final_prediction``,
    ``confidence`` (averaged score of the chosen class) and ``probabilities``
    (averaged score per label).

    Raises:
        HTTPException: 400 when no files are supplied, 400/413/415 from
            per-file validation and preprocessing, 500 when the model is not
            loaded.
    """
    if not files:
        _raise_http_error(400, "No images uploaded.")

    images: List[np.ndarray] = []
    for file in files:
        # Bounded read: one byte past the limit is enough for the validator to
        # reject an oversized file without loading all of it into memory.
        file_bytes = await file.read(MAX_FILE_SIZE_BYTES + 1)
        _validate_image_bytes(file_bytes, file.filename, file.content_type)
        images.append(_preprocess_image_bytes(file_bytes))

    batch = np.stack(images, axis=0)

    # Inference is blocking, so run it off the event loop.
    predictions = await asyncio.to_thread(_predict_batch_numpy, batch)

    # Average over the image axis to get one score per label.
    average_probabilities = np.mean(predictions, axis=0)
    predicted_index = int(np.argmax(average_probabilities))

    return _no_cache_response(
        {
            "num_images": len(files),
            "final_prediction": INDEX_TO_LABEL[predicted_index],
            "confidence": float(average_probabilities[predicted_index]),
            "probabilities": _format_probabilities(average_probabilities),
        }
    )