"""Gradio demo that classifies handwritten digits with a ResNet18 model.

The script builds a ResNet18 with a 10-way classification head, loads
trained weights from ``fulldigits.pt``, and exposes a ``predict`` function
through a Gradio interface.
"""

from typing import Dict, Optional

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

NUM_CLASSES = 10  # digits 0-9
WEIGHTS_PATH = "fulldigits.pt"
INPUT_SIZE = (128, 128)  # must match the size used during training

# 1. SETUP MODEL
# ResNet18 structure to match training; the final fully connected layer is
# replaced so the network outputs NUM_CLASSES logits.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# Load the trained (98.79% accuracy) weights. A failure is reported but does
# not abort start-up; `_weights_loaded` lets `predict` refuse to serve
# predictions from an untrained network.
_weights_loaded = False
try:
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # a trusted source; consider `weights_only=True` if the installed torch
    # version supports it.
    state_dict = torch.load(WEIGHTS_PATH, map_location="cpu")
    model.load_state_dict(state_dict)
except Exception as e:
    print(f"Error loading model: {e}")
else:
    _weights_loaded = True

# Always switch to inference mode, even when loading failed, so BatchNorm
# layers use their running statistics and are not updated by requests.
model.eval()


# 2. PREPROCESSING
def _to_rgb(image: Image.Image) -> Image.Image:
    """Return *image* converted to 3-channel RGB, as the ResNet stem expects."""
    return image.convert("RGB")


# Must use the same ImageNet normalization statistics as training.
transform = transforms.Compose([
    transforms.Lambda(_to_rgb),  # Force RGB
    transforms.Resize(INPUT_SIZE),  # Match training size
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# 3. PREDICT FUNCTION
def predict(image: Optional[Image.Image]) -> Optional[Dict[str, float]]:
    """Classify a digit image and return per-class probabilities.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the user submitted nothing.

    Returns:
        A mapping from the digit label (``"0"`` .. ``"9"``) to its softmax
        probability, or ``None`` when *image* is ``None``.

    Raises:
        gr.Error: If the model weights failed to load at start-up, so the
            UI shows an explicit error instead of meaningless predictions.
    """
    if image is None:
        return None
    if not _weights_loaded:
        raise gr.Error(
            "Model weights could not be loaded; predictions are unavailable."
        )

    # Add a leading batch dimension: the model expects a batch of images.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)

    # logits[0] is the single sample's score vector; softmax over its classes.
    probabilities = torch.softmax(logits[0], dim=0).tolist()
    return {str(digit): prob for digit, prob in enumerate(probabilities)}


# 4. INTERFACE
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Draw or Upload Digit"),
    outputs=gr.Label(num_top_classes=3),
    title="Handwritten Digit Recognizer",
    description="A ResNet18 model fine-tuned to 98.79% accuracy.",
)

if __name__ == "__main__":
    demo.launch()