Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
# 1. SETUP MODEL
# ResNet18 backbone (no pretrained weights) with a 10-way head, matching the
# architecture the checkpoint was trained with.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 10)  # Adjust head to 10 classes

# Load the fine-tuned weights (reported 98.79% accuracy).
# Loading is best-effort: on failure the app still starts, but we say loudly
# that the network is running on random initial weights.
try:
    state_dict = torch.load("fulldigits.pt", map_location="cpu")
    model.load_state_dict(state_dict)
except Exception as e:
    print(f"Error loading model: {e}")
    print("WARNING: using randomly initialized weights; predictions are not meaningful.")

# Always switch to evaluation mode, even if loading failed, so layers such as
# BatchNorm use their stored statistics instead of per-request batch statistics.
model.eval()
# 2. PREPROCESSING
# Normalization statistics: the standard ImageNet mean/std. These must be the
# same values used at training time.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _force_rgb(img):
    """Return *img* converted to 3-channel RGB (e.g. from grayscale or RGBA)."""
    return img.convert("RGB")


# PIL image -> RGB -> 128x128 (training resolution) -> float tensor -> normalized.
transform = transforms.Compose(
    [
        transforms.Lambda(_force_rgb),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# 3. PREDICT FUNCTION
def predict(image):
    """Classify a digit image and return per-class probabilities.

    Args:
        image: The input image passed in by Gradio (a PIL image for this
            interface), or ``None`` when the input is empty.

    Returns:
        ``None`` if no image was given. Otherwise, a dict mapping each class
        index as a string ("0", "1", ...) to its softmax probability. This is
        the mapping format consumed by ``gr.Label``.
    """
    if image is None:
        return None
    # Prepend a batch dimension of size 1 before running the network.
    batch = transform(image).unsqueeze(0)
    with torch.inference_mode():
        logits = model(batch)
    probabilities = torch.nn.functional.softmax(logits[0], dim=0).tolist()
    # Take the class count from the model output instead of hard-coding 10,
    # so the mapping stays correct if the classification head changes.
    return {str(i): p for i, p in enumerate(probabilities)}
# 4. INTERFACE
# Input: a PIL image the user uploads or draws.
# Output: the three most probable digits.
_digit_input = gr.Image(type="pil", label="Draw or Upload Digit")
_top3_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=predict,
    inputs=_digit_input,
    outputs=_top3_output,
    title="Handwritten Digit Recognizer",
    description="A ResNet18 model fine-tuned to 98.79% accuracy.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not when imported.
    demo.launch()