# tftest01 / app.py
# Uploaded by asdf98 ("Upload 7 files", commit 3a93dfe, verified)
from pathlib import Path

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
# --- Config ---
# Resolve the model file relative to this script rather than the current
# working directory, so the app also starts when launched from elsewhere.
MODEL_PATH = str(Path(__file__).resolve().parent / "model" / "best_model.h5")
# Class names indexed by the model's output position (predict() maps the argmax
# index into this list), so the order must match the model's output units.
LABELS = ["class1", "class2", "class3", "class4"]  # ← update to your labels
# Side length in pixels that every input image is resized to before inference.
# NOTE(review): presumably the model's training input size — confirm against
# best_model.h5.
IMG_SIZE = 512
# --- Load model ---
# compile=False: serving only needs the forward pass, and skipping compilation
# avoids load failures when the model was saved with custom losses/metrics.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# --- Preprocessing ---
def preprocess(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a model-ready batch containing one image.

    The image is converted to RGB *before* it is resized. Pillow always
    resizes "P" (palette) and "1" (bilevel) images with nearest-neighbour
    sampling, whatever filter is requested. Converting first means palette
    GIFs/PNGs get a properly interpolated resize. The conversion also drops
    any alpha channel and expands greyscale to three channels.

    Args:
        image: Input image in any Pillow mode.

    Returns:
        A batch of shape (1, IMG_SIZE, IMG_SIZE, 3) holding the resized
        pixels, passed through ``mobilenet_v3.preprocess_input``.
    """
    rgb = image.convert("RGB")
    resized = rgb.resize((IMG_SIZE, IMG_SIZE))
    # np.array (not asarray) makes a writable copy of the pixel buffer.
    batch = np.array(resized)[np.newaxis, ...]
    # NOTE(review): per the Keras docs, this function is a pass-through for
    # MobileNetV3 (the applications model rescales internally), so pixels stay
    # in [0, 255]. Confirm that matches how best_model.h5 was trained.
    return tf.keras.applications.mobilenet_v3.preprocess_input(batch)
# --- Prediction function ---
def predict(image):
    """Classify one image and return its top class in ``gr.Label`` format.

    Args:
        image: PIL image supplied by ``gr.Image(type="pil")``, or ``None`` when
            the user submits without uploading anything.

    Returns:
        A one-entry dict ``{label: confidence}`` for the highest-scoring model
        output. The label is "Unknown" when the winning index has no entry in
        LABELS.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    batch = preprocess(image)
    # verbose=0 stops predict() from printing a progress bar to the server
    # log on every request.
    scores = model.predict(batch, verbose=0)[0]
    idx = int(np.argmax(scores))
    label = LABELS[idx] if idx < len(LABELS) else "Unknown"
    # NOTE(review): this value is only a probability if the model's final
    # layer applies softmax — confirm for best_model.h5.
    return {label: float(scores[idx])}
# --- Launch Gradio Interface ---
def _build_interface():
    """Assemble the Gradio UI: one PIL image in, the top predicted class out."""
    image_input = gr.Image(type="pil")
    # Show only the single best class, matching the one-entry dict from predict().
    label_output = gr.Label(num_top_classes=1)
    return gr.Interface(
        fn=predict,
        inputs=image_input,
        outputs=label_output,
        title="My TF Classifier",
        description="Upload an image and get back its class and confidence.",
    )


iface = _build_interface()

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    iface.launch()