EnYa32 commited on
Commit
a798df7
Β·
verified Β·
1 Parent(s): 16d2b3d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +57 -14
src/streamlit_app.py CHANGED
@@ -25,6 +25,22 @@ CLASS_NAMES_PATH = "src/class_names4.json"
25
 
26
  IMG_SIZE = (224, 224)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # --------------------------------------------------
29
  # Helpers
30
  # --------------------------------------------------
@@ -35,12 +51,39 @@ def _safe_listdir(path: str):
35
  return f"Could not list dir '{path}': {e}"
36
 
37
  def load_class_names(path: str) -> list:
 
 
 
 
 
 
38
  if not os.path.exists(path):
39
- raise FileNotFoundError(f"Missing class names file: {path}")
40
- with open(path, "r", encoding="utf-8") as f:
41
- names = json.load(f)
 
 
 
 
 
 
 
42
  if not isinstance(names, list) or len(names) == 0:
43
- raise ValueError("class_names JSON must be a non-empty list.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  return names
45
 
46
  def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
@@ -59,7 +102,7 @@ def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
59
  tf.keras.layers.Dense(num_classes, activation="softmax"),
60
  ])
61
 
62
- # Compile is not required for inference, but harmless:
63
  model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
64
  return model
65
 
@@ -71,9 +114,7 @@ def load_model_and_assets(weights_path: str, class_names_path: str):
71
  if not os.path.exists(weights_path):
72
  raise FileNotFoundError(f"Missing weights file: {weights_path}")
73
 
74
- # Load weights into the exact same architecture
75
  model.load_weights(weights_path)
76
-
77
  return model, class_names
78
 
79
  def preprocess_image(pil_img: Image.Image) -> np.ndarray:
@@ -105,12 +146,12 @@ except Exception as e:
105
  st.stop()
106
 
107
  # --------------------------------------------------
108
- # Show available classes
109
  # --------------------------------------------------
110
  with st.expander("πŸ§ͺ Available mushroom classes (you can test these)"):
111
  st.write(f"Total classes: **{len(class_names)}**")
112
  for i, name in enumerate(class_names):
113
- st.write(f"**{i}** β€” {name}")
114
 
115
  # --------------------------------------------------
116
  # Image upload + prediction
@@ -128,10 +169,11 @@ else:
128
 
129
  x = preprocess_image(img)
130
 
131
- preds = model.predict(x, verbose=0)[0] # shape: (num_classes,)
132
  pred_idx = int(np.argmax(preds))
133
  pred_conf = float(preds[pred_idx])
134
- pred_name = class_names[pred_idx] if pred_idx < len(class_names) else f"Class {pred_idx}"
 
135
 
136
  st.subheader("βœ… Prediction")
137
  st.write(f"**Predicted class index:** {pred_idx}")
@@ -141,6 +183,7 @@ else:
141
  st.subheader("🏁 Top-3 predictions")
142
  top3_idx = np.argsort(preds)[::-1][:3]
143
  for rank, idx in enumerate(top3_idx, start=1):
144
- name = class_names[int(idx)]
145
- prob = float(preds[int(idx)])
146
- st.write(f"{rank}. **{name}** (class {int(idx)}) β€” **{prob*100:.2f}%**")
 
 
25
 
26
  IMG_SIZE = (224, 224)
27
 
28
+ # --------------------------------------------------
29
+ # βœ… FIXED CLASS NAMES (YOUR TABLE) β€” fallback if JSON is wrong
30
+ # --------------------------------------------------
31
+ CLASS_NAMES_TABLE = [
32
+ "amanita", # 0
33
+ "boletus", # 1
34
+ "chantelle", # 2
35
+ "deterrimus", # 3
36
+ "rufus", # 4
37
+ "torminosus", # 5
38
+ "aurantiacum", # 6
39
+ "procera", # 7
40
+ "involutus", # 8
41
+ "russula", # 9
42
+ ]
43
+
44
  # --------------------------------------------------
45
  # Helpers
46
  # --------------------------------------------------
 
51
  return f"Could not list dir '{path}': {e}"
52
 
53
  def load_class_names(path: str) -> list:
54
+ """
55
+ Loads class names from JSON.
56
+ If JSON is missing or invalid or looks like ["0","1","2"...],
57
+ we fall back to CLASS_NAMES_TABLE.
58
+ """
59
+ # If file not found -> fallback
60
  if not os.path.exists(path):
61
+ return CLASS_NAMES_TABLE
62
+
63
+ # Try load JSON
64
+ try:
65
+ with open(path, "r", encoding="utf-8") as f:
66
+ names = json.load(f)
67
+ except Exception:
68
+ return CLASS_NAMES_TABLE
69
+
70
+ # Must be list and non-empty
71
  if not isinstance(names, list) or len(names) == 0:
72
+ return CLASS_NAMES_TABLE
73
+
74
+ # If JSON contains only numbers as strings -> it's wrong -> fallback
75
+ # Example: ["0","1","2","3"...]
76
+ if all(isinstance(x, str) and x.strip().isdigit() for x in names):
77
+ return CLASS_NAMES_TABLE
78
+
79
+ # If JSON contains dict like {"amanita":0,...} -> convert to correct order
80
+ if isinstance(names, dict):
81
+ # Expect name->idx mapping
82
+ idx_to_name = {int(v): k for k, v in names.items()}
83
+ ordered = [idx_to_name[i] for i in range(len(idx_to_name))]
84
+ return ordered
85
+
86
+ # Otherwise: assume it's already correct list of names
87
  return names
88
 
89
  def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
 
102
  tf.keras.layers.Dense(num_classes, activation="softmax"),
103
  ])
104
 
105
+ # Compile not required for inference, but ok
106
  model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
107
  return model
108
 
 
114
  if not os.path.exists(weights_path):
115
  raise FileNotFoundError(f"Missing weights file: {weights_path}")
116
 
 
117
  model.load_weights(weights_path)
 
118
  return model, class_names
119
 
120
  def preprocess_image(pil_img: Image.Image) -> np.ndarray:
 
146
  st.stop()
147
 
148
  # --------------------------------------------------
149
+ # Show available classes (INDEX + NAME)
150
  # --------------------------------------------------
151
  with st.expander("πŸ§ͺ Available mushroom classes (you can test these)"):
152
  st.write(f"Total classes: **{len(class_names)}**")
153
  for i, name in enumerate(class_names):
154
+ st.write(f"**{i}** β€” **{name}**")
155
 
156
  # --------------------------------------------------
157
  # Image upload + prediction
 
169
 
170
  x = preprocess_image(img)
171
 
172
+ preds = model.predict(x, verbose=0)[0]
173
  pred_idx = int(np.argmax(preds))
174
  pred_conf = float(preds[pred_idx])
175
+
176
+ pred_name = class_names[pred_idx] if 0 <= pred_idx < len(class_names) else f"Class {pred_idx}"
177
 
178
  st.subheader("βœ… Prediction")
179
  st.write(f"**Predicted class index:** {pred_idx}")
 
183
  st.subheader("🏁 Top-3 predictions")
184
  top3_idx = np.argsort(preds)[::-1][:3]
185
  for rank, idx in enumerate(top3_idx, start=1):
186
+ idx = int(idx)
187
+ name = class_names[idx] if 0 <= idx < len(class_names) else f"Class {idx}"
188
+ prob = float(preds[idx])
189
+ st.write(f"{rank}. **{name}** (class {idx}) β€” **{prob*100:.2f}%**")