import streamlit as st
import torch
from torchvision import transforms, models
from PIL import Image
import numpy as np
import pandas as pd
from collections import defaultdict
import os
# Title
# Page heading rendered as raw HTML, hence unsafe_allow_html=True.
st.markdown("<h1>📷 Upload & Predict</h1>", unsafe_allow_html=True)
# Introductory help text for the whole page. Blank lines are significant in
# Markdown: without them a paragraph followed by `---` becomes a heading and
# text after a list item is folded into that item.
st.markdown("""
### 📌 About This Feature: Upload & Predict

This section of the **DR Assistive Tool** allows users to upload retinal images and get an AI-based prediction of the **Diabetic Retinopathy stage**. It uses a fine-tuned **DenseNet-121** model trained specifically for detecting DR severity levels.

The model classifies the uploaded image into one of the five classes:

- **No DR**
- **Mild**
- **Moderate**
- **Severe**
- **Proliferative DR**

This is especially helpful for:

- Students learning about AI in healthcare
- Researchers testing model robustness
- Clinicians exploring AI-assisted screening tools

The tool also shows **sample images from the test set** for each class. You can use these images to test the model’s performance and understand what different DR stages look like.

---

### 🧭 How to Use:

1. 👀 **View sample images** from the test set grouped by DR stage.
    - Click the **"🔍 Predict"** button under a sample image to test how the model classifies it.
2. 📁 **Upload your own retinal image** (in JPG or PNG format) using the file uploader.
3. 🧠 Click the **"Predict"** button after uploading.
    - The model will analyze the image and display:
        - 🎯 **Predicted DR Stage**
        - 📊 **Model confidence score (in %)**

⚠️ *Make sure your image is a clear, centered fundus photograph for best results.*

---

### 🔍 Behind the Scenes:

- ✅ Model: Pretrained **DenseNet-121**
- 🖼 Input size: Images are resized to 224×224 pixels
- 🔄 Normalization: Matches ImageNet pretraining stats
- 📦 Output: Highest probability class from 5 DR categories using **softmax**

*This tool is for educational and research purposes only — not for clinical use.*
""", unsafe_allow_html=True)
# DR class names
# Position matters: the model's output index i maps to class_names[i], and the
# sample loader matches the CSV's integer `label` column against these indices.
class_names = [
    'No DR',
    'Mild',
    'Moderate',
    'Severe',
    'Proliferative DR',
]
# Load sample images from CSV with proper label mapping
@st.cache_data
def load_sample_images_from_csv(csv_path=r'D:\DR_Classification\splits\test_labels.csv',
                                max_per_class=5):
    """Collect existing sample image paths for every DR class from a test-split CSV.

    The CSV needs a ``label`` column, compared against each class's position in
    ``class_names``, and a ``new_path`` column holding the image file path. For
    each class the first ``max_per_class`` matching rows (file order) are taken,
    and paths that do not exist on disk are then dropped. A class can therefore
    end up with fewer than ``max_per_class`` images.

    Args:
        csv_path: Path to the CSV. The default is the author's local Windows
            layout; pass another path on other machines.
        max_per_class: Number of leading rows considered per class. The sample
            grid renders 5 columns, so values above 5 do not fit that UI.

    Returns:
        ``defaultdict(list)`` mapping class name -> list of image paths. A class
        with no existing images gets no key.

    Raises:
        FileNotFoundError: if ``csv_path`` does not exist (from ``pd.read_csv``).
    """
    df = pd.read_csv(csv_path)
    samples = defaultdict(list)
    # Iterate class_names directly rather than a hard-coded range(5), so the
    # label -> name mapping stays in sync with the class list.
    for label, class_name in enumerate(class_names):
        candidate_paths = df.loc[df['label'] == label, 'new_path'].head(max_per_class)
        for img_path in candidate_paths:
            if os.path.exists(img_path):
                samples[class_name].append(img_path)
    return samples
# Load pretrained model
@st.cache_resource
def load_model(weights_path="Model/Pretrained_Densenet-121.pth"):
    """Build a DenseNet-121 classifier for ``class_names`` and load fine-tuned weights.

    Wrapped in ``st.cache_resource`` so the model is constructed once and reused
    across Streamlit reruns instead of being reloaded on every button click.

    Args:
        weights_path: Path to a saved ``state_dict`` for this architecture.
            Relative paths resolve against the current working directory.

    Returns:
        The model on CPU, switched to eval mode.
    """
    # weights=None: no ImageNet download; every parameter comes from the checkpoint.
    # (Replaces the deprecated ``pretrained=False`` keyword.)
    model = models.densenet121(weights=None)
    # Swap the 1000-way ImageNet head for one output per DR class.
    model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
    # weights_only=True limits unpickling to tensors/containers, so a tampered
    # checkpoint file cannot run arbitrary code when loaded.
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Image transform function
# Preprocessing applied to every image before inference: resize so the shorter
# side is 256 px, take the central 224x224 crop, convert the PIL image to a
# tensor, then normalize each channel with the ImageNet mean/std below.
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Prediction function
def predict_image(model, image):
    """Classify one image and return the top class with its confidence.

    Args:
        model: Classifier producing one logit per entry of ``class_names``.
        image: PIL image fed through ``transform`` (callers in this file pass
            RGB images).

    Returns:
        ``(class_name, confidence)``. ``class_name`` is the arg-max label, and
        ``confidence`` is that class's softmax probability as a percentage.
    """
    batch = transform(image).unsqueeze(0)  # prepend a batch axis of size 1
    with torch.no_grad():
        logits = model(batch)
        best = logits.max(dim=1).indices
        confidence = logits.softmax(dim=1)[0, best].item() * 100
    return class_names[best.item()], confidence
# Create two tabs for better separation of features
tab1, tab2 = st.tabs(["🧪 Sample Images", "📤 Upload & Predict"])
with tab1:
    st.markdown("### 🧪 Sample Images from Test Set")
    # Blank lines and indented sub-bullets are required for correct Markdown nesting.
    st.markdown("""
#### 📌 About This Feature: Sample Images

In this tab, you can explore sample retinal images from the test set, grouped by their **Diabetic Retinopathy (DR)** stage. This helps you:

- Understand the **visual differences** between DR stages
- Test the model’s performance on known data
- Get familiar with the model’s prediction behavior

#### 🧭 How to Use:

1. Browse the sample images under each DR class.
2. Click **🔍 Predict** under an image to let the AI model analyze it.
3. The result will show:
    - 🎯 **Predicted DR stage**
    - 📊 **Confidence score**

> *Ideal for researchers and students testing the model with known data.*
""", unsafe_allow_html=True)
    try:
        sample_images = load_sample_images_from_csv()
    except FileNotFoundError as err:
        # Report instead of crashing the page; a failed call is not cached, so a
        # later rerun retries once the CSV is in place.
        st.error(f"⚠️ Could not load the test-set sample list: {err}")
        sample_images = {}
    for class_name in class_names:
        class_images = sample_images.get(class_name, [])
        if not class_images:
            st.warning(f"⚠️ No images found for **{class_name}**")
            continue
        # Label each row so users can tell which DR stage the images belong to.
        st.markdown(f"#### {class_name}")
        # The loader returns at most 5 paths per class by default, matching 5 columns.
        cols = st.columns(5)
        for i, img_path in enumerate(class_images):
            with cols[i]:
                st.image(img_path, use_container_width=True)
                if st.button("🔍 Predict", key=f"predict_{img_path}_{i}"):
                    with st.spinner('Analyzing image...'):
                        # Context manager guarantees the file handle is released.
                        with Image.open(img_path) as img:
                            image = img.convert('RGB')
                        model = load_model()
                        pred_class, prob = predict_image(model, image)
                    st.success(f"🎯 Prediction: **{pred_class}** ({prob:.2f}% confidence)")
with tab2:
    st.markdown("### 📤 Upload & Predict")
    # Blank lines and indented sub-bullets are required for correct Markdown nesting.
    st.markdown("""
#### 📌 About This Feature: Upload & Predict

This tool allows you to upload a **retinal image** and get an **AI-based prediction** of the DR stage using a fine-tuned **DenseNet-121** model.

The model classifies the image into one of:

- No DR
- Mild
- Moderate
- Severe
- Proliferative DR

#### 🧭 How to Use:

1. 📁 Upload a **clear fundus image** (JPG or PNG).
2. 🧠 Click **Predict** to let the model analyze it.
3. ✅ You'll see:
    - 🎯 The predicted DR stage
    - 📊 Confidence level (in percentage)
""", unsafe_allow_html=True)
    # "jpeg" is listed too so files saved with that extension are accepted.
    uploaded_file = st.file_uploader("📁 Upload Retinal Image", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file).convert('RGB')
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # non-image data and OSError for truncated/corrupt files.
            st.error("⚠️ Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
        else:
            st.image(image, caption='🖼 Uploaded Image', use_container_width=True)
            if st.button("🧠 Predict"):
                with st.spinner('Analyzing image...'):
                    model = load_model()
                    pred_class, prob = predict_image(model, image)
                st.success(f"🎯 Prediction: **{pred_class}** ({prob:.2f}% confidence)")