Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import gradio as gr | |
# Model architecture definitions.
class ImageClassificationBase(nn.Module):
    """``nn.Module`` base class adding per-batch and per-epoch validation helpers.

    Subclasses supply ``forward``; the helpers run the model via ``self(images)``.
    """

    def validation_step(self, batch):
        """Evaluate one validation batch.

        Args:
            batch: a pair ``(images, labels)``.

        Returns:
            dict with ``'val_loss'`` (detached cross-entropy loss tensor) and
            ``'val_acc'`` (the value returned by ``accuracy``).
        """
        images, labels = batch
        logits = self(images)
        return {
            'val_loss': F.cross_entropy(logits, labels).detach(),
            'val_acc': accuracy(logits, labels),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch results produced by ``validation_step``.

        Args:
            outputs: list of dicts, each holding ``'val_loss'`` and ``'val_acc'``
                tensors.

        Returns:
            dict with the same keys, each the mean over batches as a Python float.
        """
        return {
            metric: torch.stack([step[metric] for step in outputs]).mean().item()
            for metric in ('val_loss', 'val_acc')
        }
def accuracy(outputs, labels):
    """Return the fraction of predictions that match ``labels``.

    Args:
        outputs: score tensor; the predicted class is the index of the maximum
            along dim 1.
        labels: target class indices, compared element-wise with the predictions.

    Returns:
        A 0-d tensor built from the Python float ``correct / len(predictions)``.
    """
    predicted = outputs.max(dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))
class Classifier(ImageClassificationBase):
    """Three-stage CNN with a two-way linear classification head.

    Each stage is Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(3, 3),
    with channel widths 3 -> 12 -> 15 -> 10. The stages are followed by
    Flatten and Linear(810, 2). 810 = 10 * 9 * 9: a 256x256 input shrinks to
    85, then 28, then 9 pixels per side through the three pools.

    The position of each layer inside ``self.network`` determines the
    state-dict keys (``network.0``, ``network.3``, ``network.6``,
    ``network.10``). Keep the layer order in sync with saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        layers = []
        for in_channels, out_channels in ((3, 12), (12, 15), (15, 10)):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(3, 3),
            ])
        layers.extend([nn.Flatten(), nn.Linear(810, 2)])
        self.network = nn.Sequential(*layers)

    def forward(self, xb):
        """Map a batch of 3-channel images to raw (pre-softmax) class logits."""
        logits = self.network(xb)
        return logits
# Load the trained model.
# Checkpoint file holding a Classifier state_dict. The path is resolved
# relative to the current working directory.
MODEL_PATH = 'PCOS_detection_20_epochs_val_acc_1.0.pth'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Classifier().to(device)

# load_state_dict only needs a mapping of tensors. weights_only=True stops
# torch.load from unpickling arbitrary objects, so a tampered checkpoint cannot
# execute code. (The weights_only argument requires PyTorch >= 1.13.)
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device, weights_only=True)
)
# Switch to evaluation mode before serving predictions.
model.eval()
# Labels for the model outputs: predict_image pairs output column i with
# class_names[i]. Keep this order in sync with the labels used in training
# (update this if your classes are different).
class_names = ['infected', 'not_infected']

# Preprocessing: resize every input to 256x256, the resolution the
# Linear(810, 2) head in Classifier is sized for, then convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)
# Prediction function used by the Gradio interface.
def predict_image(img):
    """Classify a single image and return per-class probabilities.

    Args:
        img: a PIL image (what ``gr.Image(type="pil")`` supplies) or an array
            accepted by ``PIL.Image.fromarray``. ``None`` means no image was
            uploaded.

    Returns:
        dict mapping each entry of ``class_names`` to its softmax probability
        as a Python float (the format ``gr.Label`` displays).

    Raises:
        gr.Error: if ``img`` is None, so the UI shows a readable message.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # The first Conv2d expects 3 input channels. Grayscale ("L"), RGBA or
    # palette images would otherwise yield 1- or 4-channel tensors and fail.
    img = img.convert("RGB")

    # Add a batch dimension and move the tensor to the model's device.
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor)
    probs = F.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_names, probs))
# Create Gradio interface
title = "PCOS Detection from Ultrasound Images"
description = """
Upload an ultrasound image to detect PCOS (Polycystic Ovary Syndrome).
The model will classify the image as either 'infected' (PCOS positive) or 'not_infected' (PCOS negative).
"""
# Build the interface. Show one probability bar per known class.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=len(class_names)),
    title=title,
    description=description,
    examples=[
        # You can add example images here if you have them
    ],
)

# Launch only when executed as a script. Importing this module (e.g. to reuse
# predict_image) no longer starts a blocking server as a side effect.
# NOTE(review): share=True requests a public link when run locally/Colab;
# confirm that exposure is intended.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)