import os
from collections import defaultdict

import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from torchvision import models, transforms

# Title
# NOTE(review): the heading's original HTML markup could not be recovered from
# this copy of the file (only the text survived); a plain <h1> is assumed here.
# Restore any custom styling if it existed.
st.markdown("<h1>📷 Upload & Predict</h1>", unsafe_allow_html=True)

st.markdown("""
### 📖 About This Feature: Upload & Predict

This section of the **DR Assistive Tool** allows users to upload retinal images and get an AI-based prediction of the **Diabetic Retinopathy stage**.
It uses a fine-tuned **DenseNet-121** model trained specifically for detecting DR severity levels.

The model classifies the uploaded image into one of the five classes:
- **No DR**
- **Mild**
- **Moderate**
- **Severe**
- **Proliferative DR**

This is especially helpful for:
- Students learning about AI in healthcare
- Researchers testing model robustness
- Clinicians exploring AI-assisted screening tools

The tool also shows **sample images from the test set** for each class. You can use these images to test the model’s performance and understand what different DR stages look like.

---

### 🧭 How to Use:

1. 🔍 **View sample images** from the test set grouped by DR stage.
   - Click the **"🔍 Predict"** button under a sample image to test how the model classifies it.

2. 📁 **Upload your own retinal image** (in JPG or PNG format) using the file uploader.

3. 🧠 Click the **"Predict"** button after uploading.
   - The model will analyze the image and display:
     - 🎯 **Predicted DR Stage**
     - 📊 **Model confidence score (in %)**

⚠️ *Make sure your image is a clear, centered fundus photograph for best results.*

---

### 🛠 Behind the Scenes:

- ✅ Model: Pretrained **DenseNet-121**
- 🖼 Input size: Images are resized (shorter side 256 px) and center-cropped to 224×224 pixels
- 🔄 Normalization: Matches ImageNet pretraining stats
- 📦 Output: Highest probability class from 5 DR categories using **softmax**

*This tool is for educational and research purposes only — not for clinical use.*
""", unsafe_allow_html=True)

# DR class names; position i is the class for integer label i in the test CSV
# and for output logit i of the model.
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']


# Load sample images from CSV with proper label mapping
# NOTE(review): the default CSV path is machine-specific (a Windows drive path);
# callers should expect FileNotFoundError on other machines.
@st.cache_data
def load_sample_images_from_csv(csv_path=r'D:\DR_Classification\splits\test_labels.csv'):
    """Collect up to five existing sample image paths per DR class.

    Reads the CSV at *csv_path*. The CSV must have an integer ``label`` column,
    which indexes into ``class_names``, and a ``new_path`` column with image
    file paths. For each class, the first five rows are taken. Only paths that
    exist on disk are kept.

    Returns:
        ``defaultdict(list)`` mapping class name -> list of image paths.
        A class with no existing image files has no key.

    Raises:
        FileNotFoundError: if *csv_path* does not exist.
    """
    df = pd.read_csv(csv_path)
    samples = defaultdict(list)
    for label, class_name in enumerate(class_names):
        # head(5) is applied before the existence check, so a class can end up
        # with fewer than five samples when some files are missing.
        candidate_paths = df.loc[df['label'] == label, 'new_path'].head(5)
        for img_path in candidate_paths:
            if os.path.exists(img_path):
                samples[class_name].append(img_path)
    return samples


# Load pretrained model
@st.cache_resource
def load_model():
    """Build DenseNet-121 with a ``len(class_names)``-way head and load the fine-tuned weights.

    The backbone is created without downloaded weights, because the checkpoint
    overwrites them. The classifier is replaced with a linear layer emitting one
    logit per class. Weights are then loaded from
    ``Model/Pretrained_Densenet-121.pth``, resolved relative to the working
    directory, onto the CPU, and the model is put in eval mode. Cached via
    ``st.cache_resource`` so reruns reuse the same model object.
    """
    # `weights=None` is the current torchvision spelling of the deprecated
    # `pretrained=False`.
    model = models.densenet121(weights=None)
    model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
    # The file is fed straight to load_state_dict, so it only needs tensors;
    # weights_only=True avoids unpickling arbitrary objects.
    state_dict = torch.load(
        "Model/Pretrained_Densenet-121.pth", map_location='cpu', weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Image transform: shorter side -> 256 px, center 224x224 crop, then ImageNet
# mean/std normalization.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# Prediction function
def predict_image(model, image):
    """Classify a PIL RGB image and return ``(class_name, confidence_percent)``.

    The image is preprocessed with ``transform`` and run through *model* as a
    batch of one. The softmax probability of the top class is returned,
    scaled to 0-100.
    """
    img_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        outputs = model(img_tensor)
    probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    confidence, pred = probs.max(dim=0)
    return class_names[pred.item()], confidence.item() * 100


def _show_prediction(image):
    """Run the cached model on *image* and render the result as a success message."""
    with st.spinner('Analyzing image...'):
        model = load_model()
        pred_class, prob = predict_image(model, image)
    st.success(f"🎯 Prediction: **{pred_class}** ({prob:.2f}% confidence)")


# Create two tabs for better separation of features
tab1, tab2 = st.tabs(["🧪 Sample Images", "📤 Upload & Predict"])

with tab1:
    st.markdown("### 🧪 Sample Images from Test Set")
    st.markdown("""
#### 📖 About This Feature: Sample Images

In this tab, you can explore sample retinal images from the test set, grouped by their **Diabetic Retinopathy (DR)** stage.

This helps you:
- Understand the **visual differences** between DR stages
- Test the model’s performance on known data
- Get familiar with the model’s prediction behavior

#### 🧭 How to Use:
1. Browse the sample images under each DR class.
2. Click **🔍 Predict** under an image to let the AI model analyze it.
3. The result will show:
   - 🎯 **Predicted DR stage**
   - 📊 **Confidence score**

> *Ideal for researchers and students testing the model with known data.*
""", unsafe_allow_html=True)

    try:
        sample_images = load_sample_images_from_csv()
    except FileNotFoundError:
        # Degrade gracefully instead of crashing the whole page when the
        # (machine-specific) test-split CSV is not present.
        st.error("⚠️ Test-set label CSV not found, so sample images cannot be shown.")
        sample_images = {}

    for class_name in class_names:
        # Label each row so users can tell which DR stage the samples belong to.
        st.markdown(f"#### {class_name}")
        img_paths = sample_images.get(class_name)
        if not img_paths:
            st.warning(f"⚠️ No images found for **{class_name}**")
            continue
        # At most five paths per class (see load_sample_images_from_csv), so
        # five columns are always enough.
        cols = st.columns(5)
        for i, img_path in enumerate(img_paths):
            with cols[i]:
                st.image(img_path, use_container_width=True)
                if st.button("🔍 Predict", key=f"predict_{img_path}_{i}"):
                    # Context manager releases the file handle once converted.
                    with Image.open(img_path) as img:
                        image = img.convert('RGB')
                    _show_prediction(image)

with tab2:
    st.markdown("### 📤 Upload & Predict")
    st.markdown("""
#### 📖 About This Feature: Upload & Predict

This tool allows you to upload a **retinal image** and get an **AI-based prediction** of the DR stage using a fine-tuned **DenseNet-121** model.

The model classifies the image into one of:
- No DR
- Mild
- Moderate
- Severe
- Proliferative DR

#### 🧭 How to Use:
1. 📁 Upload a **clear fundus image** (JPG or PNG).
2. 🧠 Click **Predict** to let the model analyze it.
3. ✅ You'll see:
   - 🎯 The predicted DR stage
   - 📊 Confidence level (in percentage)
""", unsafe_allow_html=True)

    uploaded_file = st.file_uploader("📁 Upload Retinal Image", type=["jpg", "png"])
    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file).convert('RGB')
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # unrecognized data and OSError for truncated image files.
            st.error("⚠️ The uploaded file could not be read as an image. Please upload a valid JPG or PNG.")
        else:
            st.image(image, caption='🖼 Uploaded Image', use_container_width=True)
            if st.button("🧠 Predict"):
                _show_prediction(image)