EnYa32 committed on
Commit
a42482b
·
verified ·
1 Parent(s): 2b3a2d7

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +126 -86
src/streamlit_app.py CHANGED
@@ -1,121 +1,161 @@
 
 
1
  import os
2
  import json
3
  import numpy as np
4
  import streamlit as st
5
  from PIL import Image
6
-
7
  import tensorflow as tf
8
- from tensorflow.keras.applications import ResNet50
9
- from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
10
- from tensorflow.keras.models import Sequential
11
 
12
- # -----------------------------
13
  # Page config
14
- # -----------------------------
15
- st.set_page_config(page_title="ResNet50 Image Predictor", page_icon="🧠", layout="centered")
 
 
 
 
 
16
  st.title("🧠 ResNet50 Image Predictor")
17
- st.write("Loads ResNet50 architecture and then loads trained weights (HF-safe).")
18
 
19
- # -----------------------------
20
- # Paths (FIXED)
21
- # -----------------------------
 
22
  WEIGHTS_PATH = "src/resnet50_weights_noBN3.h5"
23
  CLASS_NAMES_PATH = "src/class_names3.json"
24
 
25
- # -----------------------------
26
- # Debug info
27
- # -----------------------------
28
- with st.expander("πŸ” Debug info (HuggingFace check)"):
29
- try:
30
- st.write("Files in src/:")
31
- st.write(os.listdir("src"))
32
- st.write("Weights exists:", os.path.exists(WEIGHTS_PATH))
33
- st.write("Class names exists:", os.path.exists(CLASS_NAMES_PATH))
34
- st.write("TF:", tf.__version__)
35
- except Exception as e:
36
- st.error(e)
37
 
38
- # -----------------------------
39
- # Load class names
40
- # -----------------------------
41
- def load_class_names(path: str) -> list:
 
 
 
42
  with open(path, "r") as f:
43
- names = json.load(f)
44
- if not isinstance(names, list) or len(names) == 0:
45
- raise ValueError("class_names3.json must be a non-empty list of class names.")
46
- return names
47
-
48
- # -----------------------------
49
- # Build architecture (MUST match training!)
50
- # -----------------------------
51
- def build_model(num_classes: int) -> tf.keras.Model:
52
- base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
53
- base_model.trainable = False
54
-
55
- model = Sequential([
56
- base_model,
57
- GlobalAveragePooling2D(),
58
- Dense(256, activation="relu"),
59
- Dropout(0.5),
60
- Dense(num_classes, activation="softmax"),
61
- ])
62
- return model
63
 
64
- # -----------------------------
65
- # Load model + weights
66
- # -----------------------------
67
  @st.cache_resource
68
- def load_trained_model(weights_path: str, class_names_path: str):
 
 
69
  if not os.path.exists(weights_path):
70
  raise FileNotFoundError(f"Missing weights file: {weights_path}")
71
- if not os.path.exists(class_names_path):
72
- raise FileNotFoundError(f"Missing class names file: {class_names_path}")
73
 
74
- class_names = load_class_names(class_names_path)
75
- model = build_model(num_classes=len(class_names))
76
 
77
- # IMPORTANT: build the model by calling it once (creates weights shapes)
78
- _ = model(np.zeros((1, 224, 224, 3), dtype=np.float32), training=False)
79
-
80
- # Load weights from your file name EXACTLY
81
  model.load_weights(weights_path)
82
 
83
- return model, class_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
 
 
 
85
  try:
86
- model, class_names = load_trained_model(WEIGHTS_PATH, CLASS_NAMES_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
87
  st.success("βœ… Model + weights loaded successfully!")
88
  except Exception as e:
89
  st.error("❌ Model could not be loaded.")
90
  st.exception(e)
91
  st.stop()
92
 
93
- # -----------------------------
94
- # Image upload & prediction
95
- # -----------------------------
96
- uploaded_file = st.file_uploader("Upload a mushroom image", type=["jpg", "jpeg", "png", "webp"])
 
 
 
97
 
98
  if uploaded_file is None:
99
  st.info("πŸ‘† Please upload an image to start prediction.")
100
- else:
101
- img = Image.open(uploaded_file).convert("RGB")
102
- st.image(img, caption="Uploaded image", use_container_width=True)
103
-
104
- img_resized = img.resize((224, 224))
105
- x = np.array(img_resized, dtype=np.float32)
106
- x = np.expand_dims(x, axis=0)
107
- x = tf.keras.applications.resnet50.preprocess_input(x)
108
-
109
- preds = model.predict(x, verbose=0)[0]
110
- idx = int(np.argmax(preds))
111
- conf = float(np.max(preds))
112
-
113
- st.subheader("βœ… Prediction")
114
- st.write(f"**Predicted class:** {class_names[idx]}")
115
- st.write(f"**Confidence:** {conf:.4f}")
116
 
117
- st.subheader("πŸ“Š Class probabilities")
118
- # Show top 5
119
- topk = np.argsort(preds)[::-1][:5]
120
- for i in topk:
121
- st.write(f"{class_names[i]}: {preds[i]:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# src/streamlit_app.py

import json
import os

import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image

# ==========================================================
# Page config — must run before any other Streamlit call.
# ==========================================================
st.set_page_config(
    page_title="ResNet50 Image Predictor",
    page_icon="🧠",
    layout="centered",
)

st.title("🧠 ResNet50 Image Predictor")
st.write("Classifies mushroom images using a trained ResNet50 model.")

# ==========================================================
# Paths (HF Space: repo root is /app, your files are in /app/src)
# ==========================================================
MODEL_PATH = "src/new_best3_resnet50.keras"
WEIGHTS_PATH = "src/resnet50_weights_noBN3.h5"
CLASS_NAMES_PATH = "src/class_names3.json"

# Input resolution expected by the ResNet50 backbone.
IMG_SIZE = (224, 224)
31
# ==========================================================
# Helpers
# ==========================================================
@st.cache_data
def load_class_names(path: str) -> list:
    """Load the ordered list of class names from a JSON file.

    Accepts either a JSON list (``["a", "b"]``) or a dict keyed by index
    (``{"0": "a", "1": "b"}``). Dict keys are sorted numerically when
    possible so position ``i`` lines up with model output ``i``.

    Args:
        path: Path to the class-names JSON file.

    Returns:
        Non-empty list of class-name strings.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the file does not yield a non-empty list.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing class names file: {path}")
    # Explicit UTF-8 so non-ASCII names decode the same on every platform.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Accept either list ["a","b"] or dict {"0":"a","1":"b"}
    if isinstance(data, dict):
        try:
            # Sort by numeric key so index i maps to the i-th class.
            items = sorted(data.items(), key=lambda kv: int(kv[0]))
            data = [v for _, v in items]
        except (TypeError, ValueError):
            # Non-numeric keys: fall back to the dict's value order.
            data = list(data.values())

    if not isinstance(data, list) or len(data) == 0:
        raise ValueError("class_names JSON must be a non-empty list (or dict).")

    return data
 
 
 
 
55
 
 
 
 
56
@st.cache_resource
def load_model_and_weights(model_path: str, weights_path: str):
    """Load the saved Keras model, then overlay the trained weights.

    Args:
        model_path: Path to the full ``.keras`` model file.
        weights_path: Path to the ``.h5`` weights file (must match the
            loaded architecture).

    Returns:
        A ``tf.keras.Model`` ready for inference.

    Raises:
        FileNotFoundError: If either file is missing.
    """
    # Fail fast with a clear message if an artifact is absent.
    for artifact_path, label in ((model_path, "model"), (weights_path, "weights")):
        if not os.path.exists(artifact_path):
            raise FileNotFoundError(f"Missing {label} file: {artifact_path}")

    # compile=False is sufficient (and faster) for inference-only use.
    model = tf.keras.models.load_model(model_path, compile=False)

    # Overlay the trained weights; shapes must match the loaded model.
    model.load_weights(weights_path)

    return model
71
+
72
def preprocess_for_resnet50(pil_img: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-image batch for ResNet50.

    Converts to RGB, resizes to ``IMG_SIZE``, and applies the Keras
    ResNet50 ``preprocess_input`` transform.
    """
    rgb = pil_img.convert("RGB").resize(IMG_SIZE)
    batch = np.expand_dims(np.array(rgb, dtype=np.float32), axis=0)
    return tf.keras.applications.resnet50.preprocess_input(batch)
79
+
80
# ==========================================================
# Debug info (helpful on HuggingFace)
# ==========================================================
with st.expander("🔍 Debug info (HuggingFace check)"):
    try:
        st.write("Working directory:", os.getcwd())
        st.write("Files in ./ :", os.listdir("."))
        if os.path.exists("src"):
            st.write("Files in src/:", os.listdir("src"))
        else:
            st.warning("Folder 'src' not found.")
        st.write("MODEL_PATH exists:", os.path.exists(MODEL_PATH))
        st.write("WEIGHTS_PATH exists:", os.path.exists(WEIGHTS_PATH))
        st.write("CLASS_NAMES exists:", os.path.exists(CLASS_NAMES_PATH))
        st.write("TensorFlow version:", tf.__version__)
    except Exception as err:  # best-effort: the debug panel must never crash the app
        st.exception(err)
 
98
# ==========================================================
# Load artifacts
# ==========================================================
# Class names first: the label list is needed even to show the class panel.
try:
    class_names = load_class_names(CLASS_NAMES_PATH)
except Exception as exc:
    st.error("❌ Could not load class names.")
    st.exception(exc)
    st.stop()

# Show available classes
with st.expander("🧪 Available mushroom classes (you can test these)"):
    for i, name in enumerate(class_names):
        st.write(f"**{i}** → {name}")

# Then the model; st.stop() halts the script so nothing below runs half-loaded.
try:
    model = load_model_and_weights(MODEL_PATH, WEIGHTS_PATH)
    st.success("✅ Model + weights loaded successfully!")
except Exception as exc:
    st.error("❌ Model could not be loaded.")
    st.exception(exc)
    st.stop()
120
 
121
# ==========================================================
# Upload + predict
# ==========================================================
uploaded_file = st.file_uploader(
    "Upload a mushroom image",
    type=["jpg", "jpeg", "png", "webp"],
)

if uploaded_file is None:
    st.info("👆 Please upload an image to start prediction.")
    st.stop()


def _label(idx: int) -> str:
    """Class name for *idx*, or a placeholder if the JSON list is too short."""
    return class_names[idx] if idx < len(class_names) else f"class_{idx}"


# Show the uploaded image.
img = Image.open(uploaded_file)
st.image(img, caption="Uploaded image", use_container_width=True)

# Run inference on a single-image batch.
batch = preprocess_for_resnet50(img)
preds = model.predict(batch, verbose=0)[0]  # shape: (num_classes,)

pred_idx = int(preds.argmax())
confidence = float(preds[pred_idx])

st.subheader("✅ Prediction")
st.write(f"**Predicted class index:** {pred_idx}")
st.write(f"**Predicted mushroom:** 🍄 **{_label(pred_idx)}**")
st.write(f"**Confidence:** {confidence:.4f}")

# Three highest-scoring classes, best first.
st.subheader("📊 Top-3 predictions")
for rank_idx in preds.argsort()[::-1][:3]:
    st.write(f"**{_label(rank_idx)}** → {preds[rank_idx]:.4f}")

# Full probability table, tucked away by default.
with st.expander("All class probabilities"):
    for idx, prob in enumerate(preds):
        st.write(f"{idx} → {_label(idx)}: {prob:.6f}")