| import gradio as gr
|
| import numpy as np
|
| from PIL import Image
|
| import torch
|
| import test
|
|
|
|
|
# Module-level model instance, created once when this file is imported or run.
# predict_interface passes this same object to test.predict_user_image on
# every call, so the loading work is not repeated per prediction.
model = test.load_trained_model()
|
|
|
|
|
def predict_interface(sketch_image):
    """Run digit prediction on the image drawn in the sketchpad.

    Args:
        sketch_image: The drawing as delivered by the sketchpad component
            (an array accepted by ``PIL.Image.fromarray``), or ``None`` when
            nothing has been drawn yet.

    Returns:
        A ``(message, scores)`` tuple. ``message`` is the text shown to the
        user; ``scores`` maps each class index (as a string) to the float
        value found at that position of the probabilities returned by
        ``test.predict_user_image``. When ``sketch_image`` is ``None`` the
        message asks the user to draw first and ``scores`` is empty.
    """
    # Nothing drawn: prompt the user instead of calling the model.
    if sketch_image is None:
        return "请先绘制数字", {}

    # Wrap the raw array in a PIL image, then force single-channel
    # grayscale ('L') before handing it to the prediction helper.
    drawn = Image.fromarray(sketch_image)
    grayscale = drawn.convert('L')

    pred_class, probabilities = test.predict_user_image(grayscale, model)

    # Keys are stringified indices; values are coerced to plain floats.
    scores = {
        str(index): float(value)
        for index, value in enumerate(probabilities)
    }
    message = f"识别结果: {pred_class}"
    return message, scores
|
|
|
|
|
def clear_canvas():
    """Return the values used to clear the canvas and reset the outputs.

    Returns:
        A 3-tuple ``(None, "识别结果: ", {})``: an empty canvas value, the
        default result text, and an empty probability mapping, in that
        order.
    """
    empty_canvas = None
    default_text = "识别结果: "
    empty_scores = {}
    return empty_canvas, default_text, empty_scores
|
|
|
|
|
|
|
# Build the Gradio UI: a sketchpad for drawing a digit, a label/markdown pair
# for the prediction output, and two buttons wired to the handlers above.
# NOTE(review): the nesting below (output column inside the sketchpad's row)
# was reconstructed from a source whose indentation had been lost -- confirm
# it matches the intended layout.
with gr.Blocks(title="手写数字识别") as demo:
    gr.Markdown("# 手写数字识别系统")

    with gr.Row():
        # Drawing surface; image_mode="L" requests a grayscale image and
        # invert_colors=True inverts it before it reaches the handler.
        sketch = gr.Sketchpad(
            label="绘制区域",
            shape=(750, 750),
            brush_radius=15,
            image_mode="L",
            invert_colors=True,
        )

        with gr.Column():
            result_label = gr.Label(label="概率分布", num_top_classes=5)
            output_text = gr.Markdown("识别结果: ")

    with gr.Row():
        clear_btn = gr.Button("清除", variant="secondary")
        submit_btn = gr.Button("识别", variant="primary")

    # predict_interface returns (message, scores), matching this output order.
    submit_btn.click(
        fn=predict_interface,
        inputs=sketch,
        outputs=[output_text, result_label],
    )

    # Reuse clear_canvas instead of an inline lambda that duplicated its
    # return values; it yields (canvas, text, scores) in this output order.
    clear_btn.click(
        fn=clear_canvas,
        outputs=[sketch, output_text, result_label],
    )


if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()
|
|
|