Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +126 -86
src/streamlit_app.py
CHANGED
|
@@ -1,121 +1,161 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
import numpy as np
|
| 4 |
import streamlit as st
|
| 5 |
from PIL import Image
|
| 6 |
-
|
| 7 |
import tensorflow as tf
|
| 8 |
-
from tensorflow.keras.applications import ResNet50
|
| 9 |
-
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
|
| 10 |
-
from tensorflow.keras.models import Sequential
|
| 11 |
|
| 12 |
-
#
|
| 13 |
# Page config
|
| 14 |
-
#
|
| 15 |
-
st.set_page_config(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
st.title("π§ ResNet50 Image Predictor")
|
| 17 |
-
st.write("
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
# Paths (
|
| 21 |
-
#
|
|
|
|
| 22 |
WEIGHTS_PATH = "src/resnet50_weights_noBN3.h5"
|
| 23 |
CLASS_NAMES_PATH = "src/class_names3.json"
|
| 24 |
|
| 25 |
-
|
| 26 |
-
# Debug info
|
| 27 |
-
# -----------------------------
|
| 28 |
-
with st.expander("π Debug info (HuggingFace check)"):
|
| 29 |
-
try:
|
| 30 |
-
st.write("Files in src/:")
|
| 31 |
-
st.write(os.listdir("src"))
|
| 32 |
-
st.write("Weights exists:", os.path.exists(WEIGHTS_PATH))
|
| 33 |
-
st.write("Class names exists:", os.path.exists(CLASS_NAMES_PATH))
|
| 34 |
-
st.write("TF:", tf.__version__)
|
| 35 |
-
except Exception as e:
|
| 36 |
-
st.error(e)
|
| 37 |
|
| 38 |
-
#
|
| 39 |
-
#
|
| 40 |
-
#
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
with open(path, "r") as f:
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
Dropout(0.5),
|
| 60 |
-
Dense(num_classes, activation="softmax"),
|
| 61 |
-
])
|
| 62 |
-
return model
|
| 63 |
|
| 64 |
-
# -----------------------------
|
| 65 |
-
# Load model + weights
|
| 66 |
-
# -----------------------------
|
| 67 |
@st.cache_resource
|
| 68 |
-
def
|
|
|
|
|
|
|
| 69 |
if not os.path.exists(weights_path):
|
| 70 |
raise FileNotFoundError(f"Missing weights file: {weights_path}")
|
| 71 |
-
if not os.path.exists(class_names_path):
|
| 72 |
-
raise FileNotFoundError(f"Missing class names file: {class_names_path}")
|
| 73 |
|
| 74 |
-
|
| 75 |
-
model =
|
| 76 |
|
| 77 |
-
#
|
| 78 |
-
_ = model(np.zeros((1, 224, 224, 3), dtype=np.float32), training=False)
|
| 79 |
-
|
| 80 |
-
# Load weights from your file name EXACTLY
|
| 81 |
model.load_weights(weights_path)
|
| 82 |
|
| 83 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
|
|
|
|
|
|
|
|
|
| 85 |
try:
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
st.success("β
Model + weights loaded successfully!")
|
| 88 |
except Exception as e:
|
| 89 |
st.error("β Model could not be loaded.")
|
| 90 |
st.exception(e)
|
| 91 |
st.stop()
|
| 92 |
|
| 93 |
-
#
|
| 94 |
-
#
|
| 95 |
-
#
|
| 96 |
-
uploaded_file = st.file_uploader(
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
if uploaded_file is None:
|
| 99 |
st.info("π Please upload an image to start prediction.")
|
| 100 |
-
|
| 101 |
-
img = Image.open(uploaded_file).convert("RGB")
|
| 102 |
-
st.image(img, caption="Uploaded image", use_container_width=True)
|
| 103 |
-
|
| 104 |
-
img_resized = img.resize((224, 224))
|
| 105 |
-
x = np.array(img_resized, dtype=np.float32)
|
| 106 |
-
x = np.expand_dims(x, axis=0)
|
| 107 |
-
x = tf.keras.applications.resnet50.preprocess_input(x)
|
| 108 |
-
|
| 109 |
-
preds = model.predict(x, verbose=0)[0]
|
| 110 |
-
idx = int(np.argmax(preds))
|
| 111 |
-
conf = float(np.max(preds))
|
| 112 |
-
|
| 113 |
-
st.subheader("β
Prediction")
|
| 114 |
-
st.write(f"**Predicted class:** {class_names[idx]}")
|
| 115 |
-
st.write(f"**Confidence:** {conf:.4f}")
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/streamlit_app.py
|
| 2 |
+
|
| 3 |
import os
|
| 4 |
import json
|
| 5 |
import numpy as np
|
| 6 |
import streamlit as st
|
| 7 |
from PIL import Image
|
|
|
|
| 8 |
import tensorflow as tf
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
# ==========================================================
|
| 11 |
# Page config
|
| 12 |
+
# ==========================================================
|
| 13 |
+
st.set_page_config(
|
| 14 |
+
page_title="ResNet50 Image Predictor",
|
| 15 |
+
page_icon="π§ ",
|
| 16 |
+
layout="centered",
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
st.title("π§ ResNet50 Image Predictor")
|
| 20 |
+
st.write("Classifies mushroom images using a trained ResNet50 model.")
|
| 21 |
|
| 22 |
+
# ==========================================================
|
| 23 |
+
# Paths (HF Space: repo root is /app, your files are in /app/src)
|
| 24 |
+
# ==========================================================
|
| 25 |
+
MODEL_PATH = "src/new_best3_resnet50.keras"
|
| 26 |
WEIGHTS_PATH = "src/resnet50_weights_noBN3.h5"
|
| 27 |
CLASS_NAMES_PATH = "src/class_names3.json"
|
| 28 |
|
| 29 |
+
IMG_SIZE = (224, 224)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
# ==========================================================
|
| 32 |
+
# Helpers
|
| 33 |
+
# ==========================================================
|
| 34 |
+
@st.cache_data
|
| 35 |
+
def load_class_names(path: str):
|
| 36 |
+
if not os.path.exists(path):
|
| 37 |
+
raise FileNotFoundError(f"Missing class names file: {path}")
|
| 38 |
with open(path, "r") as f:
|
| 39 |
+
data = json.load(f)
|
| 40 |
+
|
| 41 |
+
# Accept either list ["a","b"] or dict {"0":"a","1":"b"}
|
| 42 |
+
if isinstance(data, dict):
|
| 43 |
+
# sort by numeric key if possible
|
| 44 |
+
try:
|
| 45 |
+
items = sorted(data.items(), key=lambda kv: int(kv[0]))
|
| 46 |
+
data = [v for _, v in items]
|
| 47 |
+
except Exception:
|
| 48 |
+
data = list(data.values())
|
| 49 |
+
|
| 50 |
+
if not isinstance(data, list) or len(data) == 0:
|
| 51 |
+
raise ValueError("class_names JSON must be a non-empty list (or dict).")
|
| 52 |
+
|
| 53 |
+
return data
|
| 54 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
@st.cache_resource
|
| 57 |
+
def load_model_and_weights(model_path: str, weights_path: str):
|
| 58 |
+
if not os.path.exists(model_path):
|
| 59 |
+
raise FileNotFoundError(f"Missing model file: {model_path}")
|
| 60 |
if not os.path.exists(weights_path):
|
| 61 |
raise FileNotFoundError(f"Missing weights file: {weights_path}")
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
# Load architecture/model (compile False is fine for inference)
|
| 64 |
+
model = tf.keras.models.load_model(model_path, compile=False)
|
| 65 |
|
| 66 |
+
# Load trained weights (must match the loaded model)
|
|
|
|
|
|
|
|
|
|
| 67 |
model.load_weights(weights_path)
|
| 68 |
|
| 69 |
+
return model
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def preprocess_for_resnet50(pil_img: Image.Image) -> np.ndarray:
|
| 73 |
+
img = pil_img.convert("RGB").resize(IMG_SIZE)
|
| 74 |
+
x = np.array(img, dtype=np.float32)
|
| 75 |
+
x = np.expand_dims(x, axis=0)
|
| 76 |
+
x = tf.keras.applications.resnet50.preprocess_input(x)
|
| 77 |
+
return x
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ==========================================================
|
| 81 |
+
# Debug info (helpful on HuggingFace)
|
| 82 |
+
# ==========================================================
|
| 83 |
+
with st.expander("π Debug info (HuggingFace check)"):
|
| 84 |
+
try:
|
| 85 |
+
st.write("Working directory:", os.getcwd())
|
| 86 |
+
st.write("Files in ./ :", os.listdir("."))
|
| 87 |
+
if os.path.exists("src"):
|
| 88 |
+
st.write("Files in src/:", os.listdir("src"))
|
| 89 |
+
else:
|
| 90 |
+
st.warning("Folder 'src' not found.")
|
| 91 |
+
st.write("MODEL_PATH exists:", os.path.exists(MODEL_PATH))
|
| 92 |
+
st.write("WEIGHTS_PATH exists:", os.path.exists(WEIGHTS_PATH))
|
| 93 |
+
st.write("CLASS_NAMES exists:", os.path.exists(CLASS_NAMES_PATH))
|
| 94 |
+
st.write("TensorFlow version:", tf.__version__)
|
| 95 |
+
except Exception as e:
|
| 96 |
+
st.exception(e)
|
| 97 |
|
| 98 |
+
# ==========================================================
|
| 99 |
+
# Load artifacts
|
| 100 |
+
# ==========================================================
|
| 101 |
try:
|
| 102 |
+
class_names = load_class_names(CLASS_NAMES_PATH)
|
| 103 |
+
except Exception as e:
|
| 104 |
+
st.error("β Could not load class names.")
|
| 105 |
+
st.exception(e)
|
| 106 |
+
st.stop()
|
| 107 |
+
|
| 108 |
+
# Show available classes
|
| 109 |
+
with st.expander("π§ͺ Available mushroom classes (you can test these)"):
|
| 110 |
+
for i, name in enumerate(class_names):
|
| 111 |
+
st.write(f"**{i}** β {name}")
|
| 112 |
+
|
| 113 |
+
try:
|
| 114 |
+
model = load_model_and_weights(MODEL_PATH, WEIGHTS_PATH)
|
| 115 |
st.success("β
Model + weights loaded successfully!")
|
| 116 |
except Exception as e:
|
| 117 |
st.error("β Model could not be loaded.")
|
| 118 |
st.exception(e)
|
| 119 |
st.stop()
|
| 120 |
|
| 121 |
+
# ==========================================================
|
| 122 |
+
# Upload + predict
|
| 123 |
+
# ==========================================================
|
| 124 |
+
uploaded_file = st.file_uploader(
|
| 125 |
+
"Upload a mushroom image",
|
| 126 |
+
type=["jpg", "jpeg", "png", "webp"],
|
| 127 |
+
)
|
| 128 |
|
| 129 |
if uploaded_file is None:
|
| 130 |
st.info("π Please upload an image to start prediction.")
|
| 131 |
+
st.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
+
# Show image
|
| 134 |
+
img = Image.open(uploaded_file)
|
| 135 |
+
st.image(img, caption="Uploaded image", use_container_width=True)
|
| 136 |
+
|
| 137 |
+
# Predict
|
| 138 |
+
x = preprocess_for_resnet50(img)
|
| 139 |
+
preds = model.predict(x, verbose=0)[0] # shape: (num_classes,)
|
| 140 |
+
|
| 141 |
+
pred_idx = int(np.argmax(preds))
|
| 142 |
+
pred_name = class_names[pred_idx] if pred_idx < len(class_names) else f"class_{pred_idx}"
|
| 143 |
+
confidence = float(preds[pred_idx])
|
| 144 |
+
|
| 145 |
+
st.subheader("β
Prediction")
|
| 146 |
+
st.write(f"**Predicted class index:** {pred_idx}")
|
| 147 |
+
st.write(f"**Predicted mushroom:** π **{pred_name}**")
|
| 148 |
+
st.write(f"**Confidence:** {confidence:.4f}")
|
| 149 |
+
|
| 150 |
+
# Top-3
|
| 151 |
+
st.subheader("π Top-3 predictions")
|
| 152 |
+
top3_idx = np.argsort(preds)[::-1][:3]
|
| 153 |
+
for i in top3_idx:
|
| 154 |
+
name = class_names[i] if i < len(class_names) else f"class_{i}"
|
| 155 |
+
st.write(f"**{name}** β {preds[i]:.4f}")
|
| 156 |
+
|
| 157 |
+
# Full probs (optional)
|
| 158 |
+
with st.expander("All class probabilities"):
|
| 159 |
+
for i, p in enumerate(preds):
|
| 160 |
+
name = class_names[i] if i < len(class_names) else f"class_{i}"
|
| 161 |
+
st.write(f"{i} β {name}: {p:.6f}")
|