Spaces:
Sleeping
Sleeping
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
# 1. Load the saved model once at import time; every request reuses it.
# The path is resolved next to this script rather than against the current
# working directory, so the app also starts when launched from elsewhere.
MODEL_PATH = Path(__file__).resolve().with_name("digit_recognizer.keras")
# compile=False: the model is only used for inference (predict), which does not
# need the optimizer/loss/metrics, so skip re-compiling it after loading.
model = tf.keras.models.load_model(str(MODEL_PATH), compile=False)
def _to_grayscale(image):
    """Collapse an image array to a single-channel (H, W) array.

    Accepts (H, W), (H, W, 1), (H, W, 2) gray+alpha, (H, W, 3) RGB and
    (H, W, 4) RGBA arrays. The channel count is only inspected on 3-D input,
    so a 2-D image that happens to be 3 or 4 pixels wide is left alone.
    Images with an alpha channel are composited over white so strokes drawn
    on a transparent canvas stay visible once alpha is discarded; fully opaque
    pixels come out exactly as a plain colour-to-gray conversion.
    Assumes 8-bit (0-255) pixel values, like the rest of the pipeline.

    Raises:
        ValueError: if the array is not 2-D, or 3-D with 1-4 channels.
    """
    if image.ndim == 2:
        return image
    if image.ndim != 3 or not 1 <= image.shape[-1] <= 4:
        raise ValueError(f"Unsupported image shape: {image.shape}")

    channels = image.shape[-1]
    if channels == 1:
        return image[..., 0]
    if channels == 3:
        return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Remaining cases carry an alpha channel in the last position.
    if channels == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)  # ignores alpha
    else:  # 2 channels: gray + alpha
        gray = image[..., 0]
    alpha = image[..., -1].astype(np.uint32)
    # Rounded integer "over white" blend: alpha 255 keeps the gray value
    # unchanged, alpha 0 yields pure white (255).
    flat = (gray.astype(np.uint32) * alpha + 255 * (255 - alpha) + 127) // 255
    return flat.astype(np.uint8)


def classify_digit(image):
    """Predict which digit (0-9) an image shows.

    Args:
        image: The Gradio input value. This is an image array (or anything
            np.asarray accepts), or None when nothing was provided. It may
            also be a dict carrying the flattened image under "composite",
            presumably the value of Gradio 4+ editor components such as
            gr.Sketchpad (confirm against the installed Gradio version).

    Returns:
        A dict mapping "0".."9" to the model's score for that digit (the
        format gr.Label consumes), or None when there is no image to classify.

    Raises:
        ValueError: if the image array has an unsupported shape.
    """
    if isinstance(image, dict):
        image = image.get("composite")
    # Error handling: if no image is provided
    if image is None:
        return None

    # --- PREPROCESSING ---
    # 1. Single channel (alpha, if any, is flattened over white).
    image = _to_grayscale(np.asarray(image))
    if image.size == 0:
        return None

    # 2. Shrink to 28x28; INTER_AREA preserves detail better when downscaling.
    image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)

    # 3. MNIST-style models expect a light digit on a dark background, so a
    # mostly bright image (dark ink on white paper/canvas) is inverted.
    if np.mean(image) > 127:
        image = 255 - image

    # 4-5. (1 sample, 28 height, 28 width, 1 channel), scaled to [0, 1].
    batch = image.reshape(1, 28, 28, 1) / 255.0

    # --- PREDICTION ---
    # verbose=0 keeps Keras from printing a progress bar on every request.
    prediction = model.predict(batch, verbose=0).flatten()
    # Assumes the model emits (at least) 10 scores, one per digit.
    return {str(i): float(prediction[i]) for i in range(10)}
# --- GRADIO INTERFACE ---
# NOTE(review): Gradio 4+ only accepts "upload", "webcam" and "clipboard" as
# gr.Image `sources`; passing "canvas" raises at construction. Drawing is done
# with gr.Sketchpad instead, so the app exposes two tabs that both feed
# classify_digit: one to draw, one to upload.

DESCRIPTION = "Draw a digit on the canvas OR upload a photo of a digit. The model will guess what it is."


def _classify_sketch(sketch):
    """Classify a digit drawn on the Sketchpad tab.

    Args:
        sketch: The Sketchpad value. In Gradio 4+ this is presumably a dict
            whose flattened drawing is under "composite" (confirm against the
            installed version); a plain array is also accepted.

    Returns:
        classify_digit's result for the flattened drawing, or None when there
        is nothing to classify.
    """
    if isinstance(sketch, dict):
        sketch = sketch.get("composite")
    if sketch is None:
        return None
    sketch = np.asarray(sketch)
    # Strokes may sit on a transparent background. Flatten RGBA over white so
    # that dropping alpha later cannot turn background and black strokes alike
    # into black.
    if sketch.ndim == 3 and sketch.shape[-1] == 4:
        rgb = sketch[..., :3].astype(np.float32)
        alpha = sketch[..., 3:].astype(np.float32) / 255.0
        sketch = np.rint(rgb * alpha + 255.0 * (1.0 - alpha)).astype(np.uint8)
    return classify_digit(sketch)


draw_tab = gr.Interface(
    fn=_classify_sketch,
    inputs=gr.Sketchpad(
        type="numpy",
        label="Draw Digit",
        height=400,
        width=400,
    ),
    outputs=gr.Label(num_top_classes=3),
    description=DESCRIPTION,
)

upload_tab = gr.Interface(
    fn=classify_digit,
    inputs=gr.Image(
        type="numpy",
        label="Upload Digit",
        image_mode="L",  # "L" asks Gradio to convert to grayscale on input
        sources=["upload"],
        height=400,
        width=400,
    ),
    outputs=gr.Label(num_top_classes=3),
    description=DESCRIPTION,
)

interface = gr.TabbedInterface(
    [draw_tab, upload_tab],
    tab_names=["Draw", "Upload"],
    title="Handwritten Digit Recognizer",
)

if __name__ == "__main__":
    interface.launch()