| import sys
|
| from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QLabel, QPlainTextEdit, QMainWindow,QHBoxLayout
|
| from PyQt5.QtCore import Qt, QPoint
|
| from PyQt5.QtGui import QPainter, QImage, QColor,QPen
|
| import test
|
| import numpy as np
|
| import torch
|
|
|
|
|
| from test import load_trained_model
|
|
|
# Load the trained network once, at import time, so every press of the
# recognize button reuses the same model object instead of reloading it.
# NOTE(review): the local module is named `test`, which shadows the standard
# library's `test` package; this only resolves to the local file while this
# script's directory comes first on sys.path -- consider renaming it.
model = test.load_trained_model()
|
|
|
|
|
class DrawingArea(QWidget):
    """Canvas widget on which the user paints a 28x28 white-on-black bitmap.

    The bitmap (``self.image``) is displayed magnified with a grey grid overlay;
    each grid cell corresponds to exactly one pixel of the underlying image.
    """

    GRID_SIZE = 28  # side length of the underlying bitmap, in image pixels
    MARGIN = 20     # blank border around the magnified canvas, in widget pixels

    def __init__(self):
        super().__init__()

        # Use an integer number of widget pixels per bitmap pixel so that the
        # scaled image, the grid lines and the mouse hit-test all agree.
        # (750 // 28 == 26, giving a 728 px canvas.)
        self.cell_size = 750 // self.GRID_SIZE
        self.canvas_size = self.cell_size * self.GRID_SIZE
        side = self.canvas_size + 2 * self.MARGIN
        self.setFixedSize(side, side)

        self.drawing = False
        # (col, row) of the last cell painted in the current stroke segment,
        # or None when no segment is in progress.
        self.last_pos = None

        self.image = QImage(self.GRID_SIZE, self.GRID_SIZE, QImage.Format_RGB888)
        self.image.fill(Qt.black)

    def paintEvent(self, event):
        """Draw the magnified bitmap and the grid overlay."""
        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.Antialiasing, False)

            # Nearest-neighbour scaling by an integer factor keeps every bitmap
            # pixel an exact cell_size x cell_size square.
            scaled_img = self.image.scaled(
                self.canvas_size, self.canvas_size,
                Qt.KeepAspectRatio, Qt.FastTransformation,
            )
            painter.drawImage(self.MARGIN, self.MARGIN, scaled_img)

            painter.setPen(QPen(Qt.gray, 1, Qt.SolidLine))
            far_edge = self.MARGIN + self.canvas_size
            # GRID_SIZE cells need GRID_SIZE + 1 lines in each direction.
            for i in range(self.GRID_SIZE + 1):
                offset = self.MARGIN + i * self.cell_size
                painter.drawLine(self.MARGIN, offset, far_edge, offset)
                painter.drawLine(offset, self.MARGIN, offset, far_edge)
        finally:
            painter.end()

    def mousePressEvent(self, event):
        """Start a new stroke on left-button press."""
        if event.button() == Qt.LeftButton:
            self.drawing = True
            self.last_pos = None  # a new stroke must not connect to the old one
            self.handleDrawing(event.pos())

    def mouseMoveEvent(self, event):
        """Extend the current stroke while the left button is held."""
        if self.drawing:
            self.handleDrawing(event.pos())

    def mouseReleaseEvent(self, event):
        """Finish the current stroke on left-button release."""
        if event.button() == Qt.LeftButton:
            self.drawing = False
            self.last_pos = None

    def _cell_at(self, pos):
        """Return the (col, row) bitmap cell under widget point *pos*, or None if outside the canvas."""
        x = pos.x() - self.MARGIN
        y = pos.y() - self.MARGIN
        if not (0 <= x < self.canvas_size and 0 <= y < self.canvas_size):
            return None
        # canvas_size is an exact multiple of cell_size, so both results are
        # guaranteed to lie in [0, GRID_SIZE).
        return x // self.cell_size, y // self.cell_size

    def handleDrawing(self, pos):
        """Paint the bitmap cell under widget point *pos* white.

        Consecutive cells of one stroke are joined with a straight line, so
        sparse mouse-move events during fast motion do not leave gaps.
        A point outside the canvas ends the current segment.
        """
        cell = self._cell_at(pos)
        if cell is None:
            # Leaving the canvas breaks the segment so that re-entering
            # elsewhere does not draw a chord across the image.
            self.last_pos = None
            return
        if cell == self.last_pos:
            return  # still inside the previously painted cell

        painter = QPainter(self.image)
        try:
            painter.setPen(Qt.white)
            if self.last_pos is not None:
                painter.drawLine(QPoint(*self.last_pos), QPoint(*cell))
            # Always paint the end cell explicitly, independent of how the
            # line rasteriser treats end points.
            painter.drawPoint(*cell)
        finally:
            painter.end()

        self.last_pos = cell
        self.update()

    def get_image(self):
        """Return a grayscale (Format_Grayscale8) copy of the 28x28 bitmap."""
        return self.image.convertToFormat(QImage.Format_Grayscale8)

    def clear_image(self):
        """Reset the bitmap to black and forget any in-progress stroke."""
        self.image.fill(Qt.black)
        # Without this, the cell painted last before clearing could not be
        # repainted, because handleDrawing would treat it as already painted.
        self.last_pos = None
        self.update()
|
|
|
|
|
class MainWindow(QMainWindow):
    """Main window: drawing canvas, clear/recognize buttons and result labels."""

    def __init__(self):
        super().__init__()
        self.init_ui()

    def init_ui(self):
        """Build the widget tree and connect the button signals."""
        self.setWindowTitle("手写识别")
        self.setFixedSize(850, 950)

        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)

        self.drawing_area = DrawingArea()
        layout.addWidget(self.drawing_area)

        btn_layout = QHBoxLayout()
        self.clear_btn = QPushButton("清除")
        self.recognize_btn = QPushButton("识别")
        btn_layout.addWidget(self.clear_btn)
        btn_layout.addWidget(self.recognize_btn)

        self.prob_label = QLabel("概率分布:")
        self.result_label = QLabel("识别结果:")
        # The window has a fixed size, so let a long probability listing wrap
        # instead of being clipped.
        self.prob_label.setWordWrap(True)

        layout.addLayout(btn_layout)
        layout.addWidget(self.prob_label)
        layout.addWidget(self.result_label)

        self.clear_btn.clicked.connect(self.clear_all)
        self.recognize_btn.clicked.connect(self.recognize)

    def clear_all(self):
        """Erase the canvas and reset the labels so no stale prediction is shown."""
        self.drawing_area.clear_image()
        self.prob_label.setText("概率分布:")
        self.result_label.setText("识别结果:")

    def recognize(self):
        """Run the model on the current drawing and display its prediction."""
        qimg = self.drawing_area.get_image()

        try:
            pred_class, probabilities = test.predict_user_image(qimg, model)
        except Exception as exc:
            # UI boundary: with PyQt5 >= 5.5 an exception escaping a slot
            # aborts the whole application, so report the error instead.
            import traceback
            traceback.print_exc()
            self.prob_label.setText("概率分布:")
            self.result_label.setText(f"识别结果: 识别失败 ({exc})")
            return

        self.prob_label.setText(f"概率分布: {probabilities}")
        self.result_label.setText(f"识别结果: {pred_class}")
|
|
|
|
|
def main():
    """Create the Qt application, show the main window and run the event loop.

    Returns the exit code produced by the Qt event loop.
    """
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    return app.exec_()


if __name__ == "__main__":
    sys.exit(main())
|
|
|