Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +76 -15
src/streamlit_app.py
CHANGED
|
@@ -4,6 +4,7 @@ from PIL import Image, ImageOps
|
|
| 4 |
import tensorflow as tf
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
from pathlib import Path
|
|
|
|
| 7 |
|
| 8 |
st.set_page_config(page_title="MNIST Digit Recognizer", page_icon="✍️", layout="centered")
|
| 9 |
|
|
@@ -18,28 +19,79 @@ def load_model():
|
|
| 18 |
)
|
| 19 |
return tf.keras.models.load_model(MODEL_PATH)
|
| 20 |
|
| 21 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
"""
|
| 23 |
Convert uploaded image to MNIST-like 28x28 grayscale tensor (1, 28, 28, 1).
|
|
|
|
| 24 |
"""
|
| 25 |
# 1) Convert to grayscale
|
| 26 |
img = pil_img.convert("L")
|
| 27 |
|
| 28 |
-
# 2) Make
|
| 29 |
-
img = ImageOps.pad(img, (max(img.size), max(img.size)),
|
|
|
|
| 30 |
|
| 31 |
# 3) Resize to 28x28
|
| 32 |
img = img.resize((28, 28), Image.Resampling.LANCZOS)
|
| 33 |
|
| 34 |
-
#
|
| 35 |
-
arr = np.array(img).astype(np.
|
|
|
|
|
|
|
|
|
|
| 36 |
if arr.mean() < 127:
|
| 37 |
-
arr = 255
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
#
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
def plot_probabilities(probs: np.ndarray):
|
| 45 |
plt.figure(figsize=(7, 3))
|
|
@@ -54,11 +106,17 @@ def plot_probabilities(probs: np.ndarray):
|
|
| 54 |
st.title("✍️ MNIST Digit Recognizer")
|
| 55 |
st.write("Upload an image of a handwritten digit (0–9). The model will predict the digit.")
|
| 56 |
|
| 57 |
-
with st.expander("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
st.markdown(
|
| 59 |
"- Use a **single digit** centered in the image.\n"
|
| 60 |
"- Prefer **white background** and **dark digit**.\n"
|
| 61 |
-
"- PNG/JPG works fine."
|
|
|
|
| 62 |
)
|
| 63 |
|
| 64 |
uploaded = st.file_uploader("Upload an image (png/jpg/jpeg)", type=["png", "jpg", "jpeg"])
|
|
@@ -71,22 +129,25 @@ except Exception as e:
|
|
| 71 |
|
| 72 |
if uploaded is not None:
|
| 73 |
pil_img = Image.open(uploaded)
|
|
|
|
| 74 |
st.subheader("Your image")
|
| 75 |
st.image(pil_img, use_container_width=True)
|
| 76 |
|
| 77 |
-
x = preprocess_image(pil_img
|
|
|
|
|
|
|
| 78 |
probs = model.predict(x, verbose=0)[0]
|
| 79 |
pred = int(np.argmax(probs))
|
| 80 |
conf = float(np.max(probs))
|
| 81 |
|
| 82 |
st.subheader("Prediction")
|
| 83 |
-
st.metric("Predicted Digit", pred
|
| 84 |
st.write(f"Confidence: **{conf:.3f}**")
|
| 85 |
|
| 86 |
st.subheader("Model probabilities")
|
| 87 |
plot_probabilities(probs)
|
| 88 |
|
| 89 |
st.subheader("Preprocessed 28×28 image (what the model sees)")
|
| 90 |
-
st.image((x[0].reshape(28, 28) * 255).astype(np.uint8), clamp=True
|
| 91 |
else:
|
| 92 |
st.info("Upload an image to get a prediction.")
|
|
|
|
| 4 |
import tensorflow as tf
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
from pathlib import Path
|
| 7 |
+
import cv2 # NEW
|
| 8 |
|
| 9 |
st.set_page_config(page_title="MNIST Digit Recognizer", page_icon="✍️", layout="centered")
|
| 10 |
|
|
|
|
| 19 |
)
|
| 20 |
return tf.keras.models.load_model(MODEL_PATH)
|
| 21 |
|
| 22 |
+
def center_digit(binary_28: np.ndarray) -> np.ndarray:
    """
    Center the digit of a 28x28 image on a blank 28x28 canvas.

    The digit is located via the bounding box of all non-zero pixels, cropped
    out, and pasted into the middle of an all-zero ``uint8`` canvas.

    Args:
        binary_28: 2-D array, expected to be 28x28 with values in [0, 255],
            digit in white (non-zero) on a black (0) background.

    Returns:
        ``binary_28`` itself (unchanged, not copied) when it contains no
        non-zero pixel; otherwise a new 28x28 ``uint8`` array holding the
        centered digit. A crop larger than 28 pixels in either dimension is
        truncated to its top-left 28x28 region instead of raising.
    """
    side = 28

    ys, xs = np.where(binary_28 > 0)
    if xs.size == 0:
        return binary_28  # empty image fallback

    # Tight bounding box around every non-zero pixel.
    digit = binary_28[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

    # Clip BEFORE computing the paste offsets: deriving them from the
    # unclipped size yields negative offsets for an oversized crop, and the
    # canvas assignment below then fails with a shape mismatch.
    digit = digit[:side, :side]
    h, w = digit.shape
    top = (side - h) // 2
    left = (side - w) // 2

    # Create blank canvas and paste centered.
    canvas = np.zeros((side, side), dtype=np.uint8)
    canvas[top:top + h, left:left + w] = digit
    return canvas
|
| 46 |
+
|
| 47 |
+
def preprocess_image(pil_img: Image.Image,
                     do_threshold: bool = True,
                     threshold_value: int = 140,
                     do_center: bool = True) -> np.ndarray:
    """
    Turn an uploaded picture into a model-ready array of shape (1, 28, 28, 1).

    Pipeline: grayscale -> pad to a square (fill value 255) -> Lanczos resize
    to 28x28 -> invert when the mean intensity is below 127 -> 3x3 Gaussian
    blur -> inverse binary threshold (or plain inversion when thresholding is
    off) -> optional bounding-box centering -> scale to float32 in 0..1.

    Args:
        pil_img: Source image.
        do_threshold: Apply ``cv2.THRESH_BINARY_INV`` at ``threshold_value``;
            when False the blurred image is only inverted (``255 - arr``).
        threshold_value: Cut-off passed to ``cv2.threshold``.
        do_center: Re-center the result with ``center_digit``.

    Returns:
        ``float32`` array reshaped to (1, 28, 28, 1).
    """
    lanczos = Image.Resampling.LANCZOS

    # Grayscale, then pad to a square so the resize does not distort the digit.
    gray = pil_img.convert("L")
    square_side = max(gray.size)
    squared = ImageOps.pad(gray, (square_side, square_side),
                           method=lanczos, color=255)
    small = squared.resize((28, 28), lanczos)

    pixels = np.array(small).astype(np.uint8)

    # A mostly dark image is flipped so the background is light at this
    # stage; the inverse threshold / inversion below expects that polarity.
    if pixels.mean() < 127:
        pixels = 255 - pixels

    # Mild denoise before binarising.
    pixels = cv2.GaussianBlur(pixels, (3, 3), 0)

    # Either way the outcome is "digit bright on dark background".
    if do_threshold:
        digit_mask = cv2.threshold(pixels, threshold_value, 255,
                                   cv2.THRESH_BINARY_INV)[1]
    else:
        digit_mask = 255 - pixels

    if do_center:
        digit_mask = center_digit(digit_mask)

    scaled = digit_mask.astype(np.float32) / 255.0
    return scaled.reshape(1, 28, 28, 1)
|
| 95 |
|
| 96 |
def plot_probabilities(probs: np.ndarray):
|
| 97 |
plt.figure(figsize=(7, 3))
|
|
|
|
| 106 |
# Page header: app title and a one-line description.
st.title("✍️ MNIST Digit Recognizer")
st.write("Upload an image of a handwritten digit (0–9). The model will predict the digit.")

# Preprocessing controls. The three values chosen here are passed to
# preprocess_image() further down once a file has been uploaded.
with st.expander("⚙️ Preprocessing options"):
    do_threshold = st.checkbox("Use thresholding (recommended)", value=True)
    # Positional args after the label: min 0, max 255, initial value 140, step 5.
    threshold_value = st.slider("Threshold value", 0, 255, 140, 5)
    do_center = st.checkbox("Center the digit (recommended)", value=True)

# Static usage hints shown in a collapsible section.
with st.expander("ℹ️ Tips for best results"):
    st.markdown(
        "- Use a **single digit** centered in the image.\n"
        "- Prefer **white background** and **dark digit**.\n"
        "- PNG/JPG works fine.\n"
        "- If prediction is wrong, try adjusting the **threshold** slider."
    )

# Upload widget restricted to png/jpg/jpeg; the result is checked against
# None below before any prediction is attempted.
uploaded = st.file_uploader("Upload an image (png/jpg/jpeg)", type=["png", "jpg", "jpeg"])
|
|
|
|
| 129 |
|
| 130 |
# Main flow: once a file is uploaded, show it, preprocess it, run the model
# and display the prediction; otherwise show a prompt.
if uploaded is not None:
    pil_img = Image.open(uploaded)

    st.subheader("Your image")
    st.image(pil_img, use_container_width=True)

    # Forward the options selected in the "Preprocessing options" expander.
    x = preprocess_image(pil_img, do_threshold=do_threshold,
                         threshold_value=threshold_value, do_center=do_center)

    # x holds a single image, so [0] takes that image's output row.
    # NOTE(review): `model` is defined outside this view — presumably the
    # result of load_model(); confirm.
    probs = model.predict(x, verbose=0)[0]
    pred = int(np.argmax(probs))  # index of the highest score
    conf = float(np.max(probs))  # the highest score itself

    st.subheader("Prediction")
    st.metric("Predicted Digit", pred)
    st.write(f"Confidence: **{conf:.3f}**")

    st.subheader("Model probabilities")
    plot_probabilities(probs)

    st.subheader("Preprocessed 28×28 image (what the model sees)")
    # preprocess_image divides by 255.0, so scale back up to uint8 for display.
    st.image((x[0].reshape(28, 28) * 255).astype(np.uint8), clamp=True)
else:
    st.info("Upload an image to get a prediction.")
|