import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image


# Path to the trained Keras model file. The leading "./" makes it relative
# to the current working directory, so launch the app from the directory
# that contains model/.
MODEL_PATH = "./model/best_model.h5"

# Class names in model-output order: the arg-max output index is used
# directly as an index into this list.
LABELS = [
    "class1",
    "class2",
    "class3",
    "class4",
]

# Uploaded images are resized to IMG_SIZE x IMG_SIZE pixels before inference.
IMG_SIZE = 512


# Load the trained network once, at import time, so every request reuses it.
# compile=False: this app only runs inference, so the saved optimizer, loss
# and metrics are not needed. Skipping them also avoids "Unknown loss/metric"
# load errors when the model was trained with custom, unregistered objects.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)


def preprocess(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-image batch for ``model``.

    The image is converted to 3-channel RGB, resized to
    ``IMG_SIZE`` x ``IMG_SIZE``, and given a leading batch axis. It is then
    passed through ``tf.keras.applications.mobilenet_v3.preprocess_input``.

    Args:
        image: The uploaded image, in any PIL mode (RGBA, L, P, ...).

    Returns:
        A uint8 RGB batch of shape ``(1, IMG_SIZE, IMG_SIZE, 3)``, as
        returned by ``preprocess_input``.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("preprocess() expects a PIL image, got None")

    # Convert *before* resizing. Pillow always uses nearest-neighbour
    # resampling for "P" (palette) and "1" mode images, so resizing first
    # gives blocky inputs for GIFs and palette PNGs. Converting first also
    # guarantees exactly 3 channels: alpha is dropped, grayscale expanded.
    rgb = image.convert("RGB")
    resized = rgb.resize((IMG_SIZE, IMG_SIZE))

    # np.array copies the pixels into a writable (H, W, 3) array. The new
    # leading axis makes it a batch of one.
    batch = np.expand_dims(np.array(resized), axis=0)

    # NOTE(review): MobileNetV3's preprocess_input is a pass-through in
    # current Keras, because that architecture rescales internally. Confirm
    # that best_model.h5 really is a MobileNetV3-style model; otherwise it
    # may need explicit scaling here.
    return tf.keras.applications.mobilenet_v3.preprocess_input(batch)


def predict(image):
    """Classify one uploaded image and report its top class.

    Gradio supplies the upload as a PIL image, because the input component
    is ``gr.Image(type="pil")``.

    Args:
        image: The uploaded PIL image, or ``None`` if the user submitted
            without choosing a file.

    Returns:
        A one-entry dict ``{label: confidence}``, the format ``gr.Label``
        renders. ``label`` is the ``LABELS`` entry for the highest-scoring
        output. It is ``"Unknown"`` if the model has more outputs than
        ``LABELS`` has names.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    arr = preprocess(image)

    # Call the model directly instead of model.predict(). For a single small
    # batch this skips predict()'s per-call setup overhead. It also avoids
    # printing a progress bar to the server log on every request.
    # training=False keeps layers such as BatchNorm/Dropout in inference mode.
    preds = np.asarray(model(arr, training=False))[0]

    idx = int(np.argmax(preds))
    label = LABELS[idx] if idx < len(LABELS) else "Unknown"

    # NOTE(review): this is the raw model output. It is only a probability
    # if the final layer applies softmax; confirm against the trained model.
    confidence = float(preds[idx])
    return {label: confidence}


# Web UI: one image in, one class label (with its score) out.
iface = gr.Interface(
    title="My TF Classifier",
    description="Upload an image and get back its class and confidence.",
    # Each upload reaches predict() as a PIL image ...
    inputs=gr.Image(type="pil"),
    # ... and predict()'s {label: confidence} dict is shown, top entry only.
    outputs=gr.Label(num_top_classes=1),
    fn=predict,
)


if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    iface.launch()