Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +57 -14
src/streamlit_app.py
CHANGED
|
@@ -25,6 +25,22 @@ CLASS_NAMES_PATH = "src/class_names4.json"
|
|
| 25 |
|
| 26 |
IMG_SIZE = (224, 224)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# --------------------------------------------------
|
| 29 |
# Helpers
|
| 30 |
# --------------------------------------------------
|
|
@@ -35,12 +51,39 @@ def _safe_listdir(path: str):
|
|
| 35 |
return f"Could not list dir '{path}': {e}"
|
| 36 |
|
| 37 |
def load_class_names(path: str) -> list:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
if not os.path.exists(path):
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
if not isinstance(names, list) or len(names) == 0:
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
return names
|
| 45 |
|
| 46 |
def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
|
|
@@ -59,7 +102,7 @@ def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
|
|
| 59 |
tf.keras.layers.Dense(num_classes, activation="softmax"),
|
| 60 |
])
|
| 61 |
|
| 62 |
-
# Compile
|
| 63 |
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
| 64 |
return model
|
| 65 |
|
|
@@ -71,9 +114,7 @@ def load_model_and_assets(weights_path: str, class_names_path: str):
|
|
| 71 |
if not os.path.exists(weights_path):
|
| 72 |
raise FileNotFoundError(f"Missing weights file: {weights_path}")
|
| 73 |
|
| 74 |
-
# Load weights into the exact same architecture
|
| 75 |
model.load_weights(weights_path)
|
| 76 |
-
|
| 77 |
return model, class_names
|
| 78 |
|
| 79 |
def preprocess_image(pil_img: Image.Image) -> np.ndarray:
|
|
@@ -105,12 +146,12 @@ except Exception as e:
|
|
| 105 |
st.stop()
|
| 106 |
|
| 107 |
# --------------------------------------------------
|
| 108 |
-
# Show available classes
|
| 109 |
# --------------------------------------------------
|
| 110 |
with st.expander("π§ͺ Available mushroom classes (you can test these)"):
|
| 111 |
st.write(f"Total classes: **{len(class_names)}**")
|
| 112 |
for i, name in enumerate(class_names):
|
| 113 |
-
st.write(f"**{i}** β {name}")
|
| 114 |
|
| 115 |
# --------------------------------------------------
|
| 116 |
# Image upload + prediction
|
|
@@ -128,10 +169,11 @@ else:
|
|
| 128 |
|
| 129 |
x = preprocess_image(img)
|
| 130 |
|
| 131 |
-
preds = model.predict(x, verbose=0)[0]
|
| 132 |
pred_idx = int(np.argmax(preds))
|
| 133 |
pred_conf = float(preds[pred_idx])
|
| 134 |
-
|
|
|
|
| 135 |
|
| 136 |
st.subheader("β
Prediction")
|
| 137 |
st.write(f"**Predicted class index:** {pred_idx}")
|
|
@@ -141,6 +183,7 @@ else:
|
|
| 141 |
st.subheader("π Top-3 predictions")
|
| 142 |
top3_idx = np.argsort(preds)[::-1][:3]
|
| 143 |
for rank, idx in enumerate(top3_idx, start=1):
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
| 25 |
|
| 26 |
IMG_SIZE = (224, 224)
|
| 27 |
|
| 28 |
+
# --------------------------------------------------
|
| 29 |
+
# β
FIXED CLASS NAMES (YOUR TABLE) β fallback if JSON is wrong
|
| 30 |
+
# --------------------------------------------------
|
| 31 |
+
CLASS_NAMES_TABLE = [
|
| 32 |
+
"amanita", # 0
|
| 33 |
+
"boletus", # 1
|
| 34 |
+
"chantelle", # 2
|
| 35 |
+
"deterrimus", # 3
|
| 36 |
+
"rufus", # 4
|
| 37 |
+
"torminosus", # 5
|
| 38 |
+
"aurantiacum", # 6
|
| 39 |
+
"procera", # 7
|
| 40 |
+
"involutus", # 8
|
| 41 |
+
"russula", # 9
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
# --------------------------------------------------
|
| 45 |
# Helpers
|
| 46 |
# --------------------------------------------------
|
|
|
|
| 51 |
return f"Could not list dir '{path}': {e}"
|
| 52 |
|
| 53 |
def load_class_names(path: str) -> list:
|
| 54 |
+
"""
|
| 55 |
+
Loads class names from JSON.
|
| 56 |
+
If JSON is missing or invalid or looks like ["0","1","2"...],
|
| 57 |
+
we fall back to CLASS_NAMES_TABLE.
|
| 58 |
+
"""
|
| 59 |
+
# If file not found -> fallback
|
| 60 |
if not os.path.exists(path):
|
| 61 |
+
return CLASS_NAMES_TABLE
|
| 62 |
+
|
| 63 |
+
# Try load JSON
|
| 64 |
+
try:
|
| 65 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 66 |
+
names = json.load(f)
|
| 67 |
+
except Exception:
|
| 68 |
+
return CLASS_NAMES_TABLE
|
| 69 |
+
|
| 70 |
+
# Must be list and non-empty
|
| 71 |
if not isinstance(names, list) or len(names) == 0:
|
| 72 |
+
return CLASS_NAMES_TABLE
|
| 73 |
+
|
| 74 |
+
# If JSON contains only numbers as strings -> it's wrong -> fallback
|
| 75 |
+
# Example: ["0","1","2","3"...]
|
| 76 |
+
if all(isinstance(x, str) and x.strip().isdigit() for x in names):
|
| 77 |
+
return CLASS_NAMES_TABLE
|
| 78 |
+
|
| 79 |
+
# If JSON contains dict like {"amanita":0,...} -> convert to correct order
|
| 80 |
+
if isinstance(names, dict):
|
| 81 |
+
# Expect name->idx mapping
|
| 82 |
+
idx_to_name = {int(v): k for k, v in names.items()}
|
| 83 |
+
ordered = [idx_to_name[i] for i in range(len(idx_to_name))]
|
| 84 |
+
return ordered
|
| 85 |
+
|
| 86 |
+
# Otherwise: assume it's already correct list of names
|
| 87 |
return names
|
| 88 |
|
| 89 |
def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
|
|
|
|
| 102 |
tf.keras.layers.Dense(num_classes, activation="softmax"),
|
| 103 |
])
|
| 104 |
|
| 105 |
+
# Compile not required for inference, but ok
|
| 106 |
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
| 107 |
return model
|
| 108 |
|
|
|
|
| 114 |
if not os.path.exists(weights_path):
|
| 115 |
raise FileNotFoundError(f"Missing weights file: {weights_path}")
|
| 116 |
|
|
|
|
| 117 |
model.load_weights(weights_path)
|
|
|
|
| 118 |
return model, class_names
|
| 119 |
|
| 120 |
def preprocess_image(pil_img: Image.Image) -> np.ndarray:
|
|
|
|
| 146 |
st.stop()
|
| 147 |
|
| 148 |
# --------------------------------------------------
|
| 149 |
+
# Show available classes (INDEX + NAME)
|
| 150 |
# --------------------------------------------------
|
| 151 |
with st.expander("π§ͺ Available mushroom classes (you can test these)"):
|
| 152 |
st.write(f"Total classes: **{len(class_names)}**")
|
| 153 |
for i, name in enumerate(class_names):
|
| 154 |
+
st.write(f"**{i}** β **{name}**")
|
| 155 |
|
| 156 |
# --------------------------------------------------
|
| 157 |
# Image upload + prediction
|
|
|
|
| 169 |
|
| 170 |
x = preprocess_image(img)
|
| 171 |
|
| 172 |
+
preds = model.predict(x, verbose=0)[0]
|
| 173 |
pred_idx = int(np.argmax(preds))
|
| 174 |
pred_conf = float(preds[pred_idx])
|
| 175 |
+
|
| 176 |
+
pred_name = class_names[pred_idx] if 0 <= pred_idx < len(class_names) else f"Class {pred_idx}"
|
| 177 |
|
| 178 |
st.subheader("β
Prediction")
|
| 179 |
st.write(f"**Predicted class index:** {pred_idx}")
|
|
|
|
| 183 |
st.subheader("π Top-3 predictions")
|
| 184 |
top3_idx = np.argsort(preds)[::-1][:3]
|
| 185 |
for rank, idx in enumerate(top3_idx, start=1):
|
| 186 |
+
idx = int(idx)
|
| 187 |
+
name = class_names[idx] if 0 <= idx < len(class_names) else f"Class {idx}"
|
| 188 |
+
prob = float(preds[idx])
|
| 189 |
+
st.write(f"{rank}. **{name}** (class {idx}) β **{prob*100:.2f}%**")
|