EnYa32 committed on
Commit
f1a5287
·
verified ·
1 Parent(s): 451141c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +242 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,245 @@
1
- import altair as alt
 
 
 
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/streamlit_app.py
2
+ import os
3
+ import json
4
+ import pickle
5
  import numpy as np
 
6
  import streamlit as st
7
+ from PIL import Image
8
 
9
+ import tensorflow as tf
10
+ from tensorflow.keras import Sequential
11
+ from tensorflow.keras.layers import GlobalAveragePooling2D, BatchNormalization, Dense, Dropout
12
+ from tensorflow.keras.applications import ResNet50
13
+ from tensorflow.keras.applications.resnet50 import preprocess_input
14
+
15
# -----------------------------
# Page config
# -----------------------------
# set_page_config must run before any other Streamlit call on the page.
PAGE_SETTINGS = {
    "page_title": "ResNet50 Image Predictor",
    "page_icon": "🧠",
    "layout": "centered",
}
st.set_page_config(**PAGE_SETTINGS)

st.title("🧠 ResNet50 Image Predictor")
st.write("Classifies mushroom images using a trained ResNet50 model.")
26
+
27
# -----------------------------
# Fixed paths (from src/)
# -----------------------------
# NOTE(review): paths are relative, so they assume the process working
# directory is the repo root — confirm against the deployment setup.
MODEL_PATH = "src/new_best2_resnet50.keras"  # full saved Keras model (preferred)
WEIGHTS_PATH = "src/resnet50_weights2.h5"    # weights-only fallback
CLASS_NAMES_PATH = "src/class_names2"        # base path; extension probed at load time

IMG_SIZE = (224, 224)  # spatial size fed to ResNet50 (matches build_resnet50_head)
35
+
36
+ # -----------------------------
37
+ # Helpers
38
+ # -----------------------------
39
def file_exists(path: str) -> bool:
    """Return True if *path* exists on disk, False otherwise or on any error."""
    try:
        found = os.path.exists(path)
    except Exception:
        # Bad path values (e.g. None) must not crash the app; report "missing".
        return False
    return found
44
+
45
def list_src_files():
    """Return the entries of src/, or a one-item list describing the failure."""
    try:
        entries = os.listdir("src")
    except Exception as e:
        entries = [f"Could not list src/: {e}"]
    return entries
50
+
51
def load_class_names(path_base: str):
    """
    Load the class-name list from the first readable candidate file.

    Tries, in order: path_base (no extension), .pkl, .pickle, .json, .txt, .npy.
    Per-extension parsers run first; pickle is then attempted on the same file
    both as the primary format for extension-less / .pkl / .pickle candidates
    AND as a last-resort fallback when a typed parser above failed.

    Returns:
        (names, path) on success — names is a non-empty list, path the file used.
        (None, None) if no candidate could be read.
    """
    candidates = [
        path_base,
        path_base + ".pkl",
        path_base + ".pickle",
        path_base + ".json",
        path_base + ".txt",
        path_base + ".npy",
    ]

    for p in candidates:
        if not file_exists(p):
            continue

        # numpy: array of names saved with np.save
        if p.endswith(".npy"):
            try:
                arr = np.load(p, allow_pickle=True)
                names = arr.tolist()
                if isinstance(names, (list, tuple)) and len(names) > 0:
                    return list(names), p
            except Exception:
                pass  # fall through to the pickle attempt below

        # json: a plain list of strings
        if p.endswith(".json"):
            try:
                with open(p, "r", encoding="utf-8") as f:
                    names = json.load(f)
                if isinstance(names, (list, tuple)) and len(names) > 0:
                    return list(names), p
            except Exception:
                pass  # fall through to the pickle attempt below

        # txt: one class name per line, blanks skipped
        if p.endswith(".txt"):
            try:
                with open(p, "r", encoding="utf-8") as f:
                    names = [line.strip() for line in f if line.strip()]
                if len(names) > 0:
                    return names, p
            except Exception:
                pass  # fall through to the pickle attempt below

        # pickle: primary format for "no extension" / .pkl / .pickle, and the
        # fallback for every other candidate whose typed parser failed.
        # SECURITY: pickle.load executes arbitrary code — only ever point this
        # at trusted files bundled with the app, never at user uploads.
        try:
            with open(p, "rb") as f:
                names = pickle.load(f)
            if isinstance(names, (list, tuple)) and len(names) > 0:
                return list(names), p
        except Exception:
            pass

    # No candidate existed or none parsed into a non-empty list.
    return None, None
111
+
112
def build_resnet50_head(num_classes: int):
    """
    Rebuild the classification model used in training.

    The layer stack must match the training-time architecture exactly,
    otherwise the saved weights cannot be loaded onto it.
    """
    # Frozen ImageNet backbone, no classifier top.
    backbone = ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=(224, 224, 3),
    )
    backbone.trainable = False

    # Pooled features -> normalized -> 256-unit ReLU -> softmax over classes.
    head = Sequential()
    head.add(backbone)
    head.add(GlobalAveragePooling2D())
    head.add(BatchNormalization())
    head.add(Dense(256, activation="relu"))
    head.add(Dropout(0.5))
    head.add(Dense(num_classes, activation="softmax"))
    return head
132
+
133
@st.cache_resource
def load_trained_model():
    """
    Load the classifier, cached across Streamlit reruns.

    Strategy:
        1) Try loading the full .keras model from MODEL_PATH.
        2) On failure, rebuild the architecture and load weights from
           WEIGHTS_PATH (requires the class names for the output size).

    Returns:
        (model, class_names, status_message, class_names_source_path)
        class_names may be None when the full model loaded but no
        class-names file was found.

    Raises:
        RuntimeError: when neither loading path succeeds; the message
        carries both underlying errors for the UI to display.
    """
    # Load class names first: the weights fallback needs len(class_names)
    # to size the softmax layer.
    class_names, class_src = load_class_names(CLASS_NAMES_PATH)

    # ---- Try full model first
    if file_exists(MODEL_PATH):
        try:
            # compile=False: inference only, no optimizer/loss needed.
            m = tf.keras.models.load_model(MODEL_PATH, compile=False)
            return m, class_names, f"Loaded full model: {MODEL_PATH}", class_src
        except Exception as e:
            # Remember the failure; it is included in every fallback error below.
            full_model_error = str(e)
    else:
        full_model_error = f"Full model not found: {MODEL_PATH}"

    # ---- Fallback: build architecture + load weights
    if class_names is None:
        raise RuntimeError(
            "Full model loading failed AND class names could not be loaded.\n"
            f"Full model error: {full_model_error}\n"
            "Please upload class_names2 as .pkl or .txt or .json into src/."
        )

    if not file_exists(WEIGHTS_PATH):
        raise RuntimeError(
            "Full model loading failed AND weights file not found.\n"
            f"Full model error: {full_model_error}\n"
            f"Weights not found: {WEIGHTS_PATH}"
        )

    # Architecture must match training exactly for the weights to fit.
    model = build_resnet50_head(num_classes=len(class_names))

    try:
        model.load_weights(WEIGHTS_PATH)
        return model, class_names, f"Loaded weights: {WEIGHTS_PATH} (fallback)", class_src
    except Exception as e:
        raise RuntimeError(
            "Full model loading failed AND weights loading failed.\n"
            f"Full model error: {full_model_error}\n"
            f"Weights error: {e}"
        )
178
+
179
# -----------------------------
# Debug panel
# -----------------------------
# Collapsed expander showing what the deployed filesystem actually contains.
with st.expander("πŸ” Debug info (HuggingFace check)"):
    st.write("Files in src/:")
    st.write(list_src_files())
    checks = (
        ("MODEL_PATH exists:", MODEL_PATH),
        ("WEIGHTS_PATH exists:", WEIGHTS_PATH),
        ("CLASS_NAMES base exists:", CLASS_NAMES_PATH),
    )
    for label, target in checks:
        st.write(label, file_exists(target))
    st.write("TensorFlow:", tf.__version__)
189
+
190
# -----------------------------
# Load model
# -----------------------------
# load_trained_model is @st.cache_resource-cached, so reruns reuse the model.
try:
    model, class_names, load_msg, class_src = load_trained_model()
    st.success(f"βœ… Model loaded. {load_msg}")
    if class_names is not None:
        st.caption(f"Class names loaded from: {class_src}")
except Exception as e:
    # Surface the full traceback in the UI, then halt this script run —
    # nothing below can work without a model.
    st.error("❌ Model could not be loaded.")
    st.exception(e)
    st.stop()
202
+
203
# -----------------------------
# Upload & Predict
# -----------------------------
ACCEPTED_TYPES = ["jpg", "jpeg", "png", "webp"]
uploaded_file = st.file_uploader("Upload a mushroom image", type=ACCEPTED_TYPES)

# Guard clause: without an upload there is nothing to predict on this run.
if uploaded_file is None:
    st.info("πŸ‘† Please upload an image to start prediction.")
    st.stop()
214
+
215
# Show the upload, then turn it into a single-image float32 batch.
img = Image.open(uploaded_file).convert("RGB")
st.image(img, caption="Uploaded image", use_container_width=True)

# preprocess: resize to the network input size, add the batch axis,
# then apply ResNet50's channel-wise preprocessing.
batch = np.asarray(img.resize(IMG_SIZE), dtype=np.float32)[np.newaxis, ...]
x = preprocess_input(batch)
223
+
224
# predict: one forward pass; keep the single row of class probabilities.
preds = model.predict(x, verbose=0)[0]
idx = int(preds.argmax())      # index of the most probable class
conf = float(preds[idx])       # its probability
228
+
229
# output
st.subheader("βœ… Prediction")

# Prefer the human-readable class name when one exists for this index.
if class_names is not None and idx < len(class_names):
    st.write(f"**Predicted class:** {class_names[idx]}")
else:
    st.write(f"**Predicted class index:** {idx}")

st.write(f"**Confidence:** {conf:.4f}")

st.subheader("πŸ“Š Class probabilities")
# Pair names with probabilities only when the counts line up exactly;
# otherwise fall back to numeric class labels.
if class_names is not None and len(class_names) == len(preds):
    rows = list(zip(class_names, preds))
else:
    rows = [(f"Class {i}", p) for i, p in enumerate(preds)]
for name, p in rows:
    st.write(f"{name}: {float(p):.4f}")