EnYa32 committed on
Commit
f303063
·
verified ·
1 Parent(s): 021a336

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +76 -15
src/streamlit_app.py CHANGED
@@ -4,6 +4,7 @@ from PIL import Image, ImageOps
4
  import tensorflow as tf
5
  import matplotlib.pyplot as plt
6
  from pathlib import Path
 
7
 
8
  st.set_page_config(page_title="MNIST Digit Recognizer", page_icon="✍️", layout="centered")
9
 
@@ -18,28 +19,79 @@ def load_model():
18
  )
19
  return tf.keras.models.load_model(MODEL_PATH)
20
 
21
- def preprocess_image(pil_img: Image.Image) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  """
23
  Convert uploaded image to MNIST-like 28x28 grayscale tensor (1, 28, 28, 1).
 
24
  """
25
  # 1) Convert to grayscale
26
  img = pil_img.convert("L")
27
 
28
- # 2) Make it square (pad) to avoid distortion
29
- img = ImageOps.pad(img, (max(img.size), max(img.size)), method=Image.Resampling.LANCZOS, color=255)
 
30
 
31
  # 3) Resize to 28x28
32
  img = img.resize((28, 28), Image.Resampling.LANCZOS)
33
 
34
- # 4) Invert if background is dark (MNIST: black background, white digit)
35
- arr = np.array(img).astype(np.float32)
 
 
 
36
  if arr.mean() < 127:
37
- arr = 255.0 - arr # invert
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # 5) Normalize 0..1 and reshape
40
- arr = arr / 255.0
41
- arr = arr.reshape(1, 28, 28, 1)
42
- return arr
 
 
 
 
 
43
 
44
  def plot_probabilities(probs: np.ndarray):
45
  plt.figure(figsize=(7, 3))
@@ -54,11 +106,17 @@ def plot_probabilities(probs: np.ndarray):
54
  st.title("✍️ MNIST Digit Recognizer")
55
  st.write("Upload an image of a handwritten digit (0–9). The model will predict the digit.")
56
 
57
- with st.expander("Requirements for best results"):
 
 
 
 
 
58
  st.markdown(
59
  "- Use a **single digit** centered in the image.\n"
60
  "- Prefer **white background** and **dark digit**.\n"
61
- "- PNG/JPG works fine."
 
62
  )
63
 
64
  uploaded = st.file_uploader("Upload an image (png/jpg/jpeg)", type=["png", "jpg", "jpeg"])
@@ -71,22 +129,25 @@ except Exception as e:
71
 
72
  if uploaded is not None:
73
  pil_img = Image.open(uploaded)
 
74
  st.subheader("Your image")
75
  st.image(pil_img, use_container_width=True)
76
 
77
- x = preprocess_image(pil_img)
 
 
78
  probs = model.predict(x, verbose=0)[0]
79
  pred = int(np.argmax(probs))
80
  conf = float(np.max(probs))
81
 
82
  st.subheader("Prediction")
83
- st.metric("Predicted Digit", pred, delta=None)
84
  st.write(f"Confidence: **{conf:.3f}**")
85
 
86
  st.subheader("Model probabilities")
87
  plot_probabilities(probs)
88
 
89
  st.subheader("Preprocessed 28×28 image (what the model sees)")
90
- st.image((x[0].reshape(28, 28) * 255).astype(np.uint8), clamp=True, use_container_width=False)
91
  else:
92
  st.info("Upload an image to get a prediction.")
 
4
  import tensorflow as tf
5
  import matplotlib.pyplot as plt
6
  from pathlib import Path
7
+ import cv2 # NEW
8
 
9
  st.set_page_config(page_title="MNIST Digit Recognizer", page_icon="✍️", layout="centered")
10
 
 
19
  )
20
  return tf.keras.models.load_model(MODEL_PATH)
21
 
22
def center_digit(binary_28: np.ndarray) -> np.ndarray:
    """
    Center the digit on a blank canvas using its bounding box.

    Args:
        binary_28: 2-D image with the digit as non-zero pixels (255) on a
            zero background. Normally 28x28, but any 2-D shape is accepted.

    Returns:
        A uint8 array with the same shape as the input in which the digit's
        bounding box has been pasted in the middle. If the input contains no
        non-zero pixel, the input itself is returned unchanged.

    Raises:
        ValueError: If the input is not 2-D.
    """
    if binary_28.ndim != 2:
        raise ValueError(f"expected a 2-D image, got {binary_28.ndim} dimension(s)")

    # Every non-zero pixel counts as part of the digit.
    ys, xs = np.nonzero(binary_28 > 0)
    if xs.size == 0:
        return binary_28  # empty image fallback

    y_min, y_max = ys.min(), ys.max()
    x_min, x_max = xs.min(), xs.max()
    digit = binary_28[y_min:y_max + 1, x_min:x_max + 1]

    # Size the canvas from the input rather than a hard-coded 28: the crop is
    # then guaranteed to fit, so the offsets below can never go negative.
    canvas_h, canvas_w = binary_28.shape
    h, w = digit.shape
    top = (canvas_h - h) // 2
    left = (canvas_w - w) // 2

    canvas = np.zeros((canvas_h, canvas_w), dtype=np.uint8)
    canvas[top:top + h, left:left + w] = digit
    return canvas
46
+
47
def preprocess_image(pil_img: Image.Image,
                     do_threshold: bool = True,
                     threshold_value: int = 140,
                     do_center: bool = True) -> np.ndarray:
    """
    Convert an uploaded image to an MNIST-like tensor of shape (1, 28, 28, 1).

    The result is float32 in [0, 1] with a bright digit on a dark background.

    Args:
        pil_img: Uploaded image; it is converted to 8-bit grayscale first.
        do_threshold: If True, binarize with an inverse binary threshold so
            the digit becomes 255 and the background 0. If False, the
            grayscale intensities are only inverted.
        threshold_value: Cut-off (0..255) used when ``do_threshold`` is True;
            pixels at or below it are treated as digit.
        do_center: If True, re-center the digit with ``center_digit``. That
            helper treats every non-zero pixel as digit, so without
            thresholding a not-quite-black background can leave the image
            effectively unchanged.

    Returns:
        Array of shape (1, 28, 28, 1) and dtype float32.
    """
    # 1) Convert to grayscale
    img = pil_img.convert("L")

    # 2) Normalize polarity to "dark digit on light background" BEFORE padding.
    #    The padding below is white; if a dark-background image were padded
    #    first, the white bars would skew this mean-based check and would
    #    later be thresholded as digit pixels.
    if np.asarray(img).mean() < 127:
        img = ImageOps.invert(img)

    # 3) Make square (pad) to avoid distortion
    side = max(img.size)
    img = ImageOps.pad(img, (side, side),
                       method=Image.Resampling.LANCZOS, color=255)

    # 4) Resize to 28x28 and convert to numpy
    img = img.resize((28, 28), Image.Resampling.LANCZOS)
    arr = np.array(img).astype(np.uint8)

    # 5) Mild denoise (helpful for uploaded images)
    arr = cv2.GaussianBlur(arr, (3, 3), 0)

    # 6) Produce "white digit on black background"
    if do_threshold:
        # THRESH_BINARY_INV maps dark (digit) pixels to 255 and light ones to 0.
        _, arr_bin = cv2.threshold(arr, threshold_value, 255, cv2.THRESH_BINARY_INV)
    else:
        # No threshold: plain intensity inversion keeps the gray levels.
        arr_bin = 255 - arr

    # 7) Center the digit (optional)
    if do_center:
        arr_bin = center_digit(arr_bin)

    # 8) Normalize to 0..1 (MNIST style: digit bright)
    arr_norm = arr_bin.astype(np.float32) / 255.0

    # Add batch and channel axes: (1, 28, 28, 1)
    return arr_norm.reshape(1, 28, 28, 1)
95
 
96
  def plot_probabilities(probs: np.ndarray):
97
  plt.figure(figsize=(7, 3))
 
st.title("✍️ MNIST Digit Recognizer")
st.write("Upload an image of a handwritten digit (0–9). The model will predict the digit.")

# User-tunable preprocessing knobs; their values are passed to preprocess_image().
with st.expander("Preprocessing options"):
    do_threshold = st.checkbox("Use thresholding (recommended)", value=True)
    # The threshold value is only used when thresholding is enabled, so the
    # slider is greyed out otherwise instead of silently having no effect.
    threshold_value = st.slider(
        "Threshold value",
        min_value=0,
        max_value=255,
        value=140,
        step=5,
        disabled=not do_threshold,
    )
    do_center = st.checkbox("Center the digit (recommended)", value=True)

with st.expander("ℹ️ Tips for best results"):
    st.markdown(
        "- Use a **single digit** centered in the image.\n"
        "- Prefer **white background** and **dark digit**.\n"
        "- PNG/JPG works fine.\n"
        "- If prediction is wrong, try adjusting the **threshold** slider."
    )

uploaded = st.file_uploader("Upload an image (png/jpg/jpeg)", type=["png", "jpg", "jpeg"])
 
129
 
if uploaded is not None:
    try:
        pil_img = Image.open(uploaded)
        # Phone photos often store their rotation only in EXIF metadata.
        # Apply it so the digit is shown and classified upright (a rotated
        # 6 would otherwise look like a 9).
        pil_img = ImageOps.exif_transpose(pil_img)
    except OSError as err:
        # Raised by Pillow for unidentifiable, corrupt or truncated files.
        st.error(f"Could not read the uploaded image: {err}")
        st.stop()

    st.subheader("Your image")
    st.image(pil_img, use_container_width=True)

    # Build the (1, 28, 28, 1) model input using the options chosen above.
    x = preprocess_image(pil_img, do_threshold=do_threshold,
                         threshold_value=threshold_value, do_center=do_center)

    # Single-image batch: take the first (only) row of class probabilities.
    probs = model.predict(x, verbose=0)[0]
    pred = int(np.argmax(probs))
    conf = float(np.max(probs))

    st.subheader("Prediction")
    st.metric("Predicted Digit", pred)
    st.write(f"Confidence: **{conf:.3f}**")

    st.subheader("Model probabilities")
    plot_probabilities(probs)

    st.subheader("Preprocessed 28×28 image (what the model sees)")
    # Scale the 0..1 tensor back to 0..255 for display.
    st.image((x[0].reshape(28, 28) * 255).astype(np.uint8), clamp=True)
else:
    st.info("Upload an image to get a prediction.")