EnYa32 committed on
Commit
36b44ac
Β·
verified Β·
1 Parent(s): a42482b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +94 -109
src/streamlit_app.py CHANGED
@@ -1,5 +1,3 @@
1
- # src/streamlit_app.py
2
-
3
  import os
4
  import json
5
  import numpy as np
@@ -7,155 +5,142 @@ import streamlit as st
7
  from PIL import Image
8
  import tensorflow as tf
9
 
10
- # ==========================================================
11
  # Page config
12
- # ==========================================================
13
  st.set_page_config(
14
  page_title="ResNet50 Image Predictor",
15
  page_icon="🧠",
16
- layout="centered",
17
  )
18
 
19
  st.title("🧠 ResNet50 Image Predictor")
20
- st.write("Classifies mushroom images using a trained ResNet50 model.")
21
 
22
- # ==========================================================
23
- # Paths (HF Space: repo root is /app, your files are in /app/src)
24
- # ==========================================================
25
- MODEL_PATH = "src/new_best3_resnet50.keras"
26
- WEIGHTS_PATH = "src/resnet50_weights_noBN3.h5"
27
  CLASS_NAMES_PATH = "src/class_names3.json"
28
 
29
  IMG_SIZE = (224, 224)
30
 
31
- # ==========================================================
32
  # Helpers
33
- # ==========================================================
34
- @st.cache_data
35
- def load_class_names(path: str):
 
 
 
 
 
36
  if not os.path.exists(path):
37
  raise FileNotFoundError(f"Missing class names file: {path}")
38
- with open(path, "r") as f:
39
- data = json.load(f)
40
-
41
- # Accept either list ["a","b"] or dict {"0":"a","1":"b"}
42
- if isinstance(data, dict):
43
- # sort by numeric key if possible
44
- try:
45
- items = sorted(data.items(), key=lambda kv: int(kv[0]))
46
- data = [v for _, v in items]
47
- except Exception:
48
- data = list(data.values())
49
-
50
- if not isinstance(data, list) or len(data) == 0:
51
- raise ValueError("class_names JSON must be a non-empty list (or dict).")
52
-
53
- return data
54
-
 
 
 
 
 
 
 
 
55
 
56
  @st.cache_resource
57
- def load_model_and_weights(model_path: str, weights_path: str):
58
- if not os.path.exists(model_path):
59
- raise FileNotFoundError(f"Missing model file: {model_path}")
 
60
  if not os.path.exists(weights_path):
61
  raise FileNotFoundError(f"Missing weights file: {weights_path}")
62
 
63
- # Load architecture/model (compile False is fine for inference)
64
- model = tf.keras.models.load_model(model_path, compile=False)
65
-
66
- # Load trained weights (must match the loaded model)
67
  model.load_weights(weights_path)
68
 
69
- return model
70
-
71
 
72
- def preprocess_for_resnet50(pil_img: Image.Image) -> np.ndarray:
73
  img = pil_img.convert("RGB").resize(IMG_SIZE)
74
  x = np.array(img, dtype=np.float32)
75
  x = np.expand_dims(x, axis=0)
76
  x = tf.keras.applications.resnet50.preprocess_input(x)
77
  return x
78
 
79
-
80
- # ==========================================================
81
- # Debug info (helpful on HuggingFace)
82
- # ==========================================================
83
  with st.expander("πŸ” Debug info (HuggingFace check)"):
84
- try:
85
- st.write("Working directory:", os.getcwd())
86
- st.write("Files in ./ :", os.listdir("."))
87
- if os.path.exists("src"):
88
- st.write("Files in src/:", os.listdir("src"))
89
- else:
90
- st.warning("Folder 'src' not found.")
91
- st.write("MODEL_PATH exists:", os.path.exists(MODEL_PATH))
92
- st.write("WEIGHTS_PATH exists:", os.path.exists(WEIGHTS_PATH))
93
- st.write("CLASS_NAMES exists:", os.path.exists(CLASS_NAMES_PATH))
94
- st.write("TensorFlow version:", tf.__version__)
95
- except Exception as e:
96
- st.exception(e)
97
-
98
- # ==========================================================
99
- # Load artifacts
100
- # ==========================================================
101
  try:
102
- class_names = load_class_names(CLASS_NAMES_PATH)
 
103
  except Exception as e:
104
- st.error("❌ Could not load class names.")
105
  st.exception(e)
106
  st.stop()
107
 
 
108
  # Show available classes
 
109
  with st.expander("πŸ§ͺ Available mushroom classes (you can test these)"):
 
110
  for i, name in enumerate(class_names):
111
- st.write(f"**{i}** β†’ {name}")
112
-
113
- try:
114
- model = load_model_and_weights(MODEL_PATH, WEIGHTS_PATH)
115
- st.success("βœ… Model + weights loaded successfully!")
116
- except Exception as e:
117
- st.error("❌ Model could not be loaded.")
118
- st.exception(e)
119
- st.stop()
120
 
121
- # ==========================================================
122
- # Upload + predict
123
- # ==========================================================
124
  uploaded_file = st.file_uploader(
125
  "Upload a mushroom image",
126
- type=["jpg", "jpeg", "png", "webp"],
127
  )
128
 
129
  if uploaded_file is None:
130
  st.info("πŸ‘† Please upload an image to start prediction.")
131
- st.stop()
132
-
133
- # Show image
134
- img = Image.open(uploaded_file)
135
- st.image(img, caption="Uploaded image", use_container_width=True)
136
-
137
- # Predict
138
- x = preprocess_for_resnet50(img)
139
- preds = model.predict(x, verbose=0)[0] # shape: (num_classes,)
140
-
141
- pred_idx = int(np.argmax(preds))
142
- pred_name = class_names[pred_idx] if pred_idx < len(class_names) else f"class_{pred_idx}"
143
- confidence = float(preds[pred_idx])
144
-
145
- st.subheader("βœ… Prediction")
146
- st.write(f"**Predicted class index:** {pred_idx}")
147
- st.write(f"**Predicted mushroom:** πŸ„ **{pred_name}**")
148
- st.write(f"**Confidence:** {confidence:.4f}")
149
-
150
- # Top-3
151
- st.subheader("πŸ“Š Top-3 predictions")
152
- top3_idx = np.argsort(preds)[::-1][:3]
153
- for i in top3_idx:
154
- name = class_names[i] if i < len(class_names) else f"class_{i}"
155
- st.write(f"**{name}** β†’ {preds[i]:.4f}")
156
-
157
- # Full probs (optional)
158
- with st.expander("All class probabilities"):
159
- for i, p in enumerate(preds):
160
- name = class_names[i] if i < len(class_names) else f"class_{i}"
161
- st.write(f"{i} β†’ {name}: {p:.6f}")
 
 
 
1
  import os
2
  import json
3
  import numpy as np
 
5
  from PIL import Image
6
  import tensorflow as tf
7
 
8
# --------------------------------------------------
# Page config
# --------------------------------------------------
st.set_page_config(
    page_title="ResNet50 Image Predictor",
    page_icon="🧠",
    layout="centered",
)

st.title("🧠 ResNet50 Image Predictor")
st.write(
    "Classifies mushroom images using a trained ResNet50 model "
    "(architecture in code + weights from src/)."
)

# --------------------------------------------------
# Paths (fixed, HF friendly)
# --------------------------------------------------
# Trained weights + class-name list live under src/ in the Space repo.
MODEL_WEIGHTS_PATH = "src/resnet50_weights_noBN3.h5"
CLASS_NAMES_PATH = "src/class_names3.json"

# Input resolution expected by the ResNet50 backbone.
IMG_SIZE = (224, 224)
27
 
28
+ # --------------------------------------------------
29
  # Helpers
30
+ # --------------------------------------------------
31
+ def _safe_listdir(path: str):
32
+ try:
33
+ return sorted(os.listdir(path))
34
+ except Exception as e:
35
+ return f"Could not list dir '{path}': {e}"
36
+
37
def load_class_names(path: str) -> list:
    """Load the ordered class-name list from a JSON file.

    Accepts either a JSON list ``["a", "b"]`` or a dict
    ``{"0": "a", "1": "b"}`` (dict support existed in the previous
    revision of this file and is restored here so older class-name
    files keep working; lists behave exactly as before).

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the JSON is not a non-empty list/dict.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing class names file: {path}")
    with open(path, "r", encoding="utf-8") as f:
        names = json.load(f)

    if isinstance(names, dict):
        # Sort by numeric key ("0", "1", ...) when possible so index i
        # maps to the model's output neuron i; otherwise fall back to
        # the dict's insertion order.
        try:
            names = [v for _, v in sorted(names.items(), key=lambda kv: int(kv[0]))]
        except (ValueError, TypeError):
            names = list(names.values())

    if not isinstance(names, list) or len(names) == 0:
        raise ValueError("class_names JSON must be a non-empty list.")
    return names
45
+
46
def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
    """Recreate the training-time architecture in code.

    Frozen ResNet50 backbone (no top) followed by a small dense head.
    The layer types and order must stay exactly as they were when the
    .h5 weights were saved, so that ``load_weights`` lines up with this
    graph.

    Args:
        num_classes: size of the final softmax layer.
    """
    # NOTE(review): ImageNet weights are fetched here but the caller
    # then loads the full trained weights over the model — confirm the
    # .h5 file covers the backbone too, in which case weights=None
    # would skip the download.
    backbone = tf.keras.applications.ResNet50(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3),
    )
    backbone.trainable = False  # inference-only; keep the backbone frozen

    head = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ]
    classifier = tf.keras.Sequential([backbone, *head])

    # Compiling is not needed for inference, but it is harmless.
    classifier.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier
65
 
66
@st.cache_resource
def load_model_and_assets(weights_path: str, class_names_path: str):
    """Build the classifier, load its trained weights, and return
    ``(model, class_names)``.

    Cached by Streamlit (``st.cache_resource``) so the model is only
    constructed once per process.

    Raises:
        FileNotFoundError: if the weights or class-names file is missing.
        ValueError: if the class-names JSON is malformed.
    """
    # Fail fast: verify the weights file BEFORE building the backbone.
    # Building first (as the previous version did) can trigger a large
    # ImageNet-weights download only to error out afterwards.
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Missing weights file: {weights_path}")

    class_names = load_class_names(class_names_path)
    model = build_resnet50_classifier(num_classes=len(class_names))

    # Load trained weights into the exact same architecture.
    model.load_weights(weights_path)

    return model, class_names
 
78
 
79
def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """Turn a PIL image into a (1, 224, 224, 3) float32 batch.

    Resizes to IMG_SIZE, forces 3-channel RGB, and applies the same
    ``resnet50.preprocess_input`` normalization the network was
    trained with.
    """
    rgb = pil_img.convert("RGB").resize(IMG_SIZE)
    batch = np.asarray(rgb, dtype=np.float32)[np.newaxis, ...]
    return tf.keras.applications.resnet50.preprocess_input(batch)
85
 
86
# --------------------------------------------------
# Debug info (HF)
# --------------------------------------------------
# Collapsible diagnostics so path problems on the Space are visible
# without digging through container logs.
with st.expander("🔍 Debug info (HuggingFace check)"):
    st.write("Files in repo root:", _safe_listdir("."))
    st.write("Files in src/:", _safe_listdir("src"))
    for label, artifact in (
        ("Weights exists:", MODEL_WEIGHTS_PATH),
        ("Class names exists:", CLASS_NAMES_PATH),
    ):
        st.write(label, os.path.exists(artifact), "->", artifact)
    st.write("TensorFlow version:", tf.__version__)
95
+
96
# --------------------------------------------------
# Load model + assets
# --------------------------------------------------
# Any failure (missing files, architecture mismatch) is surfaced in the
# UI and stops the script — nothing below runs without a model.
try:
    model, class_names = load_model_and_assets(MODEL_WEIGHTS_PATH, CLASS_NAMES_PATH)
except Exception as e:
    st.error("❌ Model could not be loaded.")
    st.exception(e)
    st.stop()
else:
    st.success("✅ Model + weights loaded successfully!")

# --------------------------------------------------
# Show available classes
# --------------------------------------------------
with st.expander("🧪 Available mushroom classes (you can test these)"):
    st.write(f"Total classes: **{len(class_names)}**")
    for i, name in enumerate(class_names):
        st.write(f"**{i}** — {name}")
 
 
 
 
 
 
 
 
114
 
115
# --------------------------------------------------
# Image upload + prediction
# --------------------------------------------------
uploaded_file = st.file_uploader(
    "Upload a mushroom image",
    type=["jpg", "jpeg", "png", "webp"],
)

if uploaded_file is None:
    st.info("👆 Please upload an image to start prediction.")
else:
    # Decode the upload eagerly: a truncated/corrupt file should produce
    # a clear error here rather than an uncaught exception mid-prediction.
    try:
        img = Image.open(uploaded_file)
        img.load()
    except Exception as e:
        st.error("❌ Could not read the uploaded image.")
        st.exception(e)
        st.stop()

    st.image(img, caption="Uploaded image", use_container_width=True)

    x = preprocess_image(img)
    preds = model.predict(x, verbose=0)[0]  # shape: (num_classes,)

    pred_idx = int(np.argmax(preds))
    pred_conf = float(preds[pred_idx])
    # Guard against a model whose output dimension exceeds the name list.
    pred_name = class_names[pred_idx] if pred_idx < len(class_names) else f"Class {pred_idx}"

    st.subheader("✅ Prediction")
    st.write(f"**Predicted class index:** {pred_idx}")
    st.write(f"**Predicted class name:** {pred_name}")
    st.write(f"**Confidence:** {pred_conf:.4f}")

    st.subheader("🏁 Top-3 predictions")
    top3_idx = np.argsort(preds)[::-1][:3]
    for rank, idx in enumerate(top3_idx, start=1):
        i = int(idx)
        # Same bounds guard as the headline prediction — the previous
        # version indexed class_names unguarded here and would crash on
        # an out-of-range class index.
        name = class_names[i] if i < len(class_names) else f"Class {i}"
        prob = float(preds[i])
        st.write(f"{rank}. **{name}** (class {i}) — **{prob*100:.2f}%**")