EnYa32 commited on
Commit
08319ea
Β·
verified Β·
1 Parent(s): 8276d56

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +82 -198
src/streamlit_app.py CHANGED
@@ -1,20 +1,22 @@
1
- # src/streamlit_app.py
2
- import os
3
- import json
4
- import pickle
5
- import numpy as np
6
  import streamlit as st
 
7
  from PIL import Image
8
-
9
  import tensorflow as tf
10
- from tensorflow.keras import Sequential
11
- from tensorflow.keras.layers import GlobalAveragePooling2D, BatchNormalization, Dense, Dropout
 
12
  from tensorflow.keras.applications import ResNet50
13
- from tensorflow.keras.applications.resnet50 import preprocess_input
 
 
 
 
 
 
14
 
15
- # -----------------------------
16
  # Page config
17
- # -----------------------------
18
  st.set_page_config(
19
  page_title="ResNet50 Image Predictor",
20
  page_icon="🧠",
@@ -24,95 +26,37 @@ st.set_page_config(
24
  st.title("🧠 ResNet50 Image Predictor")
25
  st.write("Classifies mushroom images using a trained ResNet50 model.")
26
 
27
- # -----------------------------
28
- # Fixed paths (from src/)
29
- # -----------------------------
30
- MODEL_PATH = "src/new_best2_resnet50.keras"
31
- WEIGHTS_PATH = "src/resnet50_weights2.h5"
32
- CLASS_NAMES_PATH = "src/class_names2"
33
 
34
- IMG_SIZE = (224, 224)
35
-
36
- # -----------------------------
37
- # Helpers
38
- # -----------------------------
39
- def file_exists(path: str) -> bool:
40
- try:
41
- return os.path.exists(path)
42
- except Exception:
43
- return False
44
-
45
- def list_src_files():
46
  try:
47
- return os.listdir("src")
 
 
48
  except Exception as e:
49
- return [f"Could not list src/: {e}"]
50
-
51
- def load_class_names(path_base: str):
52
- """
53
- Loads class names from:
54
- - class_names2 (no extension) as: pickle / json / txt
55
- - class_names2.pkl / .json / .txt / .npy
56
- If nothing works, returns None.
57
- """
58
- candidates = [
59
- path_base,
60
- path_base + ".pkl",
61
- path_base + ".pickle",
62
- path_base + ".json",
63
- path_base + ".txt",
64
- path_base + ".npy",
65
- ]
66
-
67
- for p in candidates:
68
- if not file_exists(p):
69
- continue
70
-
71
- # numpy
72
- if p.endswith(".npy"):
73
- try:
74
- arr = np.load(p, allow_pickle=True)
75
- names = arr.tolist()
76
- if isinstance(names, (list, tuple)) and len(names) > 0:
77
- return list(names), p
78
- except Exception:
79
- pass
80
-
81
- # json
82
- if p.endswith(".json"):
83
- try:
84
- with open(p, "r", encoding="utf-8") as f:
85
- names = json.load(f)
86
- if isinstance(names, (list, tuple)) and len(names) > 0:
87
- return list(names), p
88
- except Exception:
89
- pass
90
-
91
- # txt
92
- if p.endswith(".txt"):
93
- try:
94
- with open(p, "r", encoding="utf-8") as f:
95
- names = [line.strip() for line in f.readlines() if line.strip()]
96
- if len(names) > 0:
97
- return names, p
98
- except Exception:
99
- pass
100
-
101
- # try pickle for "no extension" or .pkl/.pickle
102
- try:
103
- with open(p, "rb") as f:
104
- names = pickle.load(f)
105
- if isinstance(names, (list, tuple)) and len(names) > 0:
106
- return list(names), p
107
- except Exception:
108
- pass
109
-
110
- return None, None
111
-
112
- def build_resnet50_head(num_classes: int):
113
- """
114
- Must match your training architecture exactly!
115
- """
116
  base_model = ResNet50(
117
  weights="imagenet",
118
  include_top=False,
@@ -126,120 +70,60 @@ def build_resnet50_head(num_classes: int):
126
  BatchNormalization(),
127
  Dense(256, activation="relu"),
128
  Dropout(0.5),
129
- Dense(num_classes, activation="softmax")
130
  ])
131
- return model
132
 
133
- @st.cache_resource
134
- def load_trained_model():
135
- """
136
- 1) Try loading full .keras model
137
- 2) If fails -> rebuild architecture + load weights
138
- """
139
- # Load class names (needed for weights fallback)
140
- class_names, class_src = load_class_names(CLASS_NAMES_PATH)
141
-
142
- # ---- Try full model first
143
- if file_exists(MODEL_PATH):
144
- try:
145
- m = tf.keras.models.load_model(MODEL_PATH, compile=False)
146
- return m, class_names, f"Loaded full model: {MODEL_PATH}", class_src
147
- except Exception as e:
148
- full_model_error = str(e)
149
- else:
150
- full_model_error = f"Full model not found: {MODEL_PATH}"
151
-
152
- # ---- Fallback: build architecture + load weights
153
- if class_names is None:
154
- raise RuntimeError(
155
- "Full model loading failed AND class names could not be loaded.\n"
156
- f"Full model error: {full_model_error}\n"
157
- "Please upload class_names2 as .pkl or .txt or .json into src/."
158
- )
159
-
160
- if not file_exists(WEIGHTS_PATH):
161
- raise RuntimeError(
162
- "Full model loading failed AND weights file not found.\n"
163
- f"Full model error: {full_model_error}\n"
164
- f"Weights not found: {WEIGHTS_PATH}"
165
- )
166
-
167
- model = build_resnet50_head(num_classes=len(class_names))
168
 
169
- try:
170
- model.load_weights(WEIGHTS_PATH)
171
- return model, class_names, f"Loaded weights: {WEIGHTS_PATH} (fallback)", class_src
172
- except Exception as e:
173
- raise RuntimeError(
174
- "Full model loading failed AND weights loading failed.\n"
175
- f"Full model error: {full_model_error}\n"
176
- f"Weights error: {e}"
177
- )
178
-
179
- # -----------------------------
180
- # Debug panel
181
- # -----------------------------
182
- with st.expander("πŸ” Debug info (HuggingFace check)"):
183
- st.write("Files in src/:")
184
- st.write(list_src_files())
185
- st.write("MODEL_PATH exists:", file_exists(MODEL_PATH))
186
- st.write("WEIGHTS_PATH exists:", file_exists(WEIGHTS_PATH))
187
- st.write("CLASS_NAMES base exists:", file_exists(CLASS_NAMES_PATH))
188
- st.write("TensorFlow:", tf.__version__)
189
-
190
- # -----------------------------
191
- # Load model
192
- # -----------------------------
193
  try:
194
- model, class_names, load_msg, class_src = load_trained_model()
195
- st.success(f"βœ… Model loaded. {load_msg}")
196
- if class_names is not None:
197
- st.caption(f"Class names loaded from: {class_src}")
198
  except Exception as e:
199
  st.error("❌ Model could not be loaded.")
200
  st.exception(e)
201
  st.stop()
202
 
203
- # -----------------------------
204
- # Upload & Predict
205
- # -----------------------------
206
  uploaded_file = st.file_uploader(
207
  "Upload a mushroom image",
208
  type=["jpg", "jpeg", "png", "webp"]
209
  )
210
 
211
- if uploaded_file is None:
212
- st.info("πŸ‘† Please upload an image to start prediction.")
213
- st.stop()
214
-
215
- img = Image.open(uploaded_file).convert("RGB")
216
- st.image(img, caption="Uploaded image", use_container_width=True)
217
-
218
- # preprocess
219
- img_resized = img.resize(IMG_SIZE)
220
- x = np.array(img_resized, dtype=np.float32)
221
- x = np.expand_dims(x, axis=0)
222
- x = preprocess_input(x)
223
-
224
- # predict
225
- preds = model.predict(x, verbose=0)[0]
226
- idx = int(np.argmax(preds))
227
- conf = float(preds[idx])
228
-
229
- # output
230
- st.subheader("βœ… Prediction")
231
-
232
- if class_names is not None and idx < len(class_names):
233
- st.write(f"**Predicted class:** {class_names[idx]}")
234
- else:
235
- st.write(f"**Predicted class index:** {idx}")
236
-
237
- st.write(f"**Confidence:** {conf:.4f}")
 
 
238
 
239
- st.subheader("πŸ“Š Class probabilities")
240
- if class_names is not None and len(class_names) == len(preds):
241
- for name, p in zip(class_names, preds):
242
- st.write(f"{name}: {float(p):.4f}")
243
  else:
244
- for i, p in enumerate(preds):
245
- st.write(f"Class {i}: {float(p):.4f}")
 
 
 
 
 
 
1
  import streamlit as st
2
+ import numpy as np
3
  from PIL import Image
 
4
  import tensorflow as tf
5
+ import json
6
+ import os
7
+
8
  from tensorflow.keras.applications import ResNet50
9
+ from tensorflow.keras.layers import (
10
+ Dense,
11
+ Dropout,
12
+ GlobalAveragePooling2D,
13
+ BatchNormalization
14
+ )
15
+ from tensorflow.keras.models import Sequential
16
 
17
+ # --------------------------------------------------
18
  # Page config
19
+ # --------------------------------------------------
20
  st.set_page_config(
21
  page_title="ResNet50 Image Predictor",
22
  page_icon="🧠",
 
26
  st.title("🧠 ResNet50 Image Predictor")
27
  st.write("Classifies mushroom images using a trained ResNet50 model.")
28
 
29
+ # --------------------------------------------------
30
+ # Paths (FIXED NAMES)
31
+ # --------------------------------------------------
32
+ WEIGHTS_PATH = "src/resnet50_weights1.h5"
33
+ CLASS_NAMES_PATH = "src/class_names.json"
 
34
 
35
+ # --------------------------------------------------
36
+ # Debug info (VERY IMPORTANT on HuggingFace)
37
+ # --------------------------------------------------
38
+ with st.expander("πŸ” Debug info (HuggingFace check)"):
39
+ st.write("Files in src/:")
 
 
 
 
 
 
 
40
  try:
41
+ st.write(os.listdir("src"))
42
+ st.write("Weights exist:", os.path.exists(WEIGHTS_PATH))
43
+ st.write("Classes exist:", os.path.exists(CLASS_NAMES_PATH))
44
  except Exception as e:
45
+ st.error(e)
46
+
47
+ # --------------------------------------------------
48
+ # Load class names
49
+ # --------------------------------------------------
50
+ with open(CLASS_NAMES_PATH, "r") as f:
51
+ class_names = json.load(f)
52
+
53
+ NUM_CLASSES = len(class_names)
54
+
55
+ # --------------------------------------------------
56
+ # Build model + load weights (HF SAFE)
57
+ # --------------------------------------------------
58
+ @st.cache_resource
59
+ def load_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  base_model = ResNet50(
61
  weights="imagenet",
62
  include_top=False,
 
70
  BatchNormalization(),
71
  Dense(256, activation="relu"),
72
  Dropout(0.5),
73
+ Dense(NUM_CLASSES, activation="softmax")
74
  ])
 
75
 
76
+ model.load_weights(WEIGHTS_PATH)
77
+ return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ # --------------------------------------------------
80
+ # Load model safely
81
+ # --------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  try:
83
+ model = load_model()
84
+ st.success("βœ… Model loaded successfully!")
 
 
85
  except Exception as e:
86
  st.error("❌ Model could not be loaded.")
87
  st.exception(e)
88
  st.stop()
89
 
90
+ # --------------------------------------------------
91
+ # Image upload
92
+ # --------------------------------------------------
93
  uploaded_file = st.file_uploader(
94
  "Upload a mushroom image",
95
  type=["jpg", "jpeg", "png", "webp"]
96
  )
97
 
98
+ if uploaded_file is not None:
99
+ img = Image.open(uploaded_file).convert("RGB")
100
+ st.image(img, caption="Uploaded image", use_container_width=True)
101
+
102
+ # --------------------------------------------------
103
+ # Preprocessing (ResNet50)
104
+ # --------------------------------------------------
105
+ img = img.resize((224, 224))
106
+ x = np.array(img, dtype=np.float32)
107
+ x = np.expand_dims(x, axis=0)
108
+ x = tf.keras.applications.resnet50.preprocess_input(x)
109
+
110
+ # --------------------------------------------------
111
+ # Prediction
112
+ # --------------------------------------------------
113
+ preds = model.predict(x, verbose=0)
114
+ pred_idx = int(np.argmax(preds))
115
+ confidence = float(np.max(preds))
116
+
117
+ # --------------------------------------------------
118
+ # Output
119
+ # --------------------------------------------------
120
+ st.subheader("βœ… Prediction")
121
+ st.write(f"**Class:** {class_names[pred_idx]}")
122
+ st.write(f"**Confidence:** {confidence:.4f}")
123
+
124
+ st.subheader("πŸ“Š Class probabilities")
125
+ for i, p in enumerate(preds[0]):
126
+ st.write(f"{class_names[i]}: {p:.4f}")
127
 
 
 
 
 
128
  else:
129
+ st.info("πŸ‘† Please upload an image to start prediction.")