| import cv2 |
| import numpy as np |
| from PIL import Image |
| import torch |
| import torch.nn as nn |
| from torchvision import models, transforms |
| import streamlit as st |
| from typing import Tuple |
| from fpdf import FPDF |
| import io |
|
|
| |
# Run all inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



# Classifier output labels, in the index order the model head was trained with
# (presumably alphabetical folder order — TODO confirm against training code).
CLASS_NAMES = ["Mild", "Moderate", "No_DR", "Proliferate_DR", "Severe"]
# RGB colours for painting segmentation classes onto the fundus image:
# class 0 = background, class 1 = bright lesions (yellow), class 2 = red lesions.
LESION_COLORS = {
    0: [0, 0, 0],
    1: [255, 255, 0],
    2: [255, 0, 0]
}
# Mapping from classifier label to UK diabetic eye screening (R) grade text.
UK_GRADES = {
    "No_DR": "R0 - No retinopathy",
    "Mild": "R1 - Background DR",
    "Moderate": "R1 - Background DR",
    "Severe": "R2 - Pre-proliferative DR",
    "Proliferate_DR": "R3 - Proliferative DR"
}
|
|
|
|
| |
class UNet(nn.Module):
    """Compact 3-level U-Net for lesion segmentation.

    Input:  (N, input_channels, H, W) with H and W divisible by 8.
    Output: (N, num_classes, H, W) raw logits (no activation applied).

    NOTE: attribute names and layer registration order are kept exactly as in
    the trained checkpoint so ``load_state_dict`` keys continue to match.
    """

    def __init__(self, input_channels=3, num_classes=3):
        super(UNet, self).__init__()

        def double_conv(c_in, c_out):
            # Two 3x3 same-padding convolutions, each followed by ReLU then
            # BatchNorm (order matters for checkpoint key compatibility).
            stages = [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(c_out),
                nn.Conv2d(c_out, c_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(c_out),
            ]
            return nn.Sequential(*stages)

        # Contracting path: 32 -> 64 -> 128 channels, halving resolution each level.
        self.encoder1 = double_conv(input_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = double_conv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = double_conv(64, 128)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = double_conv(128, 256)

        # Expanding path: transpose-conv upsample, concat skip, refine.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = double_conv(256, 128)

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = double_conv(128, 64)

        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = double_conv(64, 32)

        # 1x1 projection to per-class logits.
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder — keep each pre-pool activation as a skip connection.
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))

        bottom = self.bottleneck(self.pool3(skip3))

        # Decoder — upsample, fuse with matching skip along channels, refine.
        up = self.decoder3(torch.cat([self.up3(bottom), skip3], dim=1))
        up = self.decoder2(torch.cat([self.up2(up), skip2], dim=1))
        up = self.decoder1(torch.cat([self.up1(up), skip1], dim=1))

        return self.final_conv(up)
|
|
| |
def create_classifier_model():
    """Build a ResNet-152 (random init) with a 5-class LogSoftmax head.

    The head emits log-probabilities, so callers apply ``torch.exp`` to
    recover probabilities. Weights are loaded separately from a checkpoint.
    """
    backbone = models.resnet152(weights=None)
    in_features = backbone.fc.in_features
    head = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Linear(512, 5),
        nn.LogSoftmax(dim=1),
    )
    backbone.fc = head
    return backbone
|
|
@st.cache_resource
def load_classifier():
    """Load the DR-stage classifier from 'classifier.pt' (cached by Streamlit)."""
    checkpoint = torch.load('classifier.pt', map_location=device)
    net = create_classifier_model().to(device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()  # inference mode: freeze batch-norm/dropout behaviour
    return net
|
|
def preprocess_classifier(image: Image.Image) -> np.ndarray:
    """Return a 3-channel image built from the CLAHE-enhanced green channel.

    The green channel carries the most vessel/lesion contrast in fundus
    photography; it is contrast-equalised and replicated across RGB so the
    result still fits a 3-channel network input.
    """
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    green = np.array(image)[:, :, 1]
    enhanced = clahe.apply(green)
    return np.dstack([enhanced, enhanced, enhanced])
|
|
def get_classifier_transform():
    """Classifier input pipeline: 224x224 resize, tensor, scale to [-1, 1]."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)
|
|
| |
@st.cache_resource
def load_segmenter():
    """Load the UNet lesion segmenter from 'best_unet_model.pth' (cached)."""
    state = torch.load('best_unet_model.pth', map_location=device)
    net = UNet().to(device)
    net.load_state_dict(state)
    net.eval()  # inference mode
    return net
|
|
def preprocess_segmenter(image: Image.Image) -> np.ndarray:
    """Denoise and contrast-enhance a fundus image for segmentation.

    Applies a 3x3 median blur, then CLAHE on the lightness channel only
    (LAB space) so colour balance is preserved. Returns an RGB uint8 array.
    """
    denoised = cv2.medianBlur(np.array(image), 3)
    lab = cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalised = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(equalised, cv2.COLOR_LAB2RGB)
|
|
def get_segmenter_transform():
    """Segmenter input pipeline: 512x512 resize, tensor, ImageNet normalisation."""
    steps = [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
|
|
def process_segmentation_output(output: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
    """Convert raw segmenter logits into a class mask and probability maps.

    Args:
        output: logits of shape (1, C, H, W).
    Returns:
        (mask, probs) where mask is an (H, W) uint8 array of argmax class
        indices and probs is the (C, H, W) softmax probability array.
    """
    probabilities = torch.softmax(output, dim=1)
    prob_maps = probabilities.cpu().numpy().squeeze()
    mask = prob_maps.argmax(axis=0).astype(np.uint8)
    return mask, prob_maps
|
|
| |
def create_lesion_overlay(original: Image.Image, mask: np.ndarray) -> Image.Image:
    """Blend lesion colours onto the original image at 40% opacity.

    The mask (model resolution) is nearest-neighbour resized to the original
    image size so class indices are preserved.

    Fix: the background class (0, black) is now skipped — previously it was
    painted into the overlay too, so addWeighted dimmed every non-lesion
    pixel of the retina by 40% instead of leaving healthy tissue untouched.
    """
    original_np = np.array(original)
    mask_resized = cv2.resize(mask, (original_np.shape[1], original_np.shape[0]),
                              interpolation=cv2.INTER_NEAREST)

    overlay = original_np.copy()
    for class_idx, color in LESION_COLORS.items():
        if class_idx == 0:
            continue  # background: keep the original pixels unmodified
        overlay[mask_resized == class_idx] = color
    return Image.fromarray(cv2.addWeighted(overlay, 0.4, original_np, 0.6, 0))
|
|
def segment_image(image: Image.Image, model: nn.Module) -> dict:
    """Run lesion segmentation on one fundus image.

    Returns a dict with:
        'mask'        - (H, W) uint8 argmax class mask,
        'probs'       - per-class probability maps,
        'bright_area' - percentage of pixels labelled class 1,
        'red_area'    - percentage of pixels labelled class 2.
    """
    enhanced = preprocess_segmenter(image)
    tensor = get_segmenter_transform()(Image.fromarray(enhanced))
    batch = tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    mask, prob_maps = process_segmentation_output(logits)
    pixel_count = mask.size
    bright_pct = np.sum(mask == 1) / pixel_count * 100
    red_pct = np.sum(mask == 2) / pixel_count * 100
    return {
        'mask': mask,
        'probs': prob_maps,
        'bright_area': bright_pct,
        'red_area': red_pct,
    }
| |
| |
def generate_pdf_report(original_img: Image.Image, mask: np.ndarray, overlay: Image.Image,
                        diagnosis: str, grade: str, bright_area: float, red_area: float):
    """Assemble a two-page PDF diagnosis report (fpdf2 API).

    Args:
        original_img: the uploaded fundus photograph.
        mask: (H, W) uint8 class mask (0/1/2); scaled by 85 for visibility.
        overlay: lesion colours blended onto the original image.
        diagnosis: predicted DR stage name.
        grade: UK screening grade string.
        bright_area / red_area: lesion coverage percentages.

    Returns:
        The PDF as ``bytes``, or ``None`` if generation failed (an error is
        shown via Streamlit in that case).
    """

    def _embed_image(pdf, img: Image.Image):
        # Serialise a PIL image to an in-memory PNG and place it at 100 mm wide.
        # (Extracted: this save-to-BytesIO dance was repeated three times.)
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        pdf.image(io.BytesIO(buf.getvalue()), x=10, w=100)

    try:
        pdf = FPDF()
        pdf.add_page()

        # Title and blank patient/date fields for manual completion.
        pdf.set_font("helvetica", "B", 16)
        pdf.cell(text="Diabetic Retinopathy Diagnosis Report", new_x="LMARGIN", new_y="NEXT", align='C')
        pdf.ln(10)

        pdf.set_font("helvetica", "", 12)
        pdf.cell(text="Patient: ___________________________", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(text="Date: _____________________________", new_x="LMARGIN", new_y="NEXT")
        pdf.ln(10)

        # Diagnosis section.
        pdf.set_font("helvetica", "B", 14)
        pdf.cell(text="Diagnosis:", new_x="LMARGIN", new_y="NEXT")
        pdf.set_font("helvetica", "", 12)
        pdf.cell(text=f"Stage: {diagnosis}", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(text=f"Grading: {grade}", new_x="LMARGIN", new_y="NEXT")
        pdf.ln(10)

        # Lesion coverage statistics.
        pdf.set_font("helvetica", "B", 14)
        pdf.cell(text="Lesion Analysis:", new_x="LMARGIN", new_y="NEXT")
        pdf.set_font("helvetica", "", 12)
        pdf.cell(text=f"Bright Lesions: {bright_area:.2f}%", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(text=f"Red Lesions: {red_area:.2f}%", new_x="LMARGIN", new_y="NEXT")
        pdf.cell(text=f"Total Affected Area: {bright_area + red_area:.2f}%", new_x="LMARGIN", new_y="NEXT")
        pdf.ln(15)

        # Page 1 image: the original scan.
        pdf.set_font("helvetica", "B", 12)
        pdf.cell(text="Original Retinal Image:", new_x="LMARGIN", new_y="NEXT")
        _embed_image(pdf, original_img)
        pdf.ln(10)

        # Page 2: segmentation results.
        pdf.add_page()

        pdf.set_font("helvetica", "B", 12)
        pdf.cell(text="Lesion Segmentation Mask:", new_x="LMARGIN", new_y="NEXT")
        # Scale class ids {0,1,2} by 85 so classes are visible as grey levels.
        _embed_image(pdf, Image.fromarray((mask * 85).astype(np.uint8)))
        pdf.ln(10)

        pdf.set_font("helvetica", "B", 12)
        pdf.cell(text="Lesion Overlay:", new_x="LMARGIN", new_y="NEXT")
        _embed_image(pdf, overlay)

        # Footer.
        pdf.ln(10)
        pdf.set_font("helvetica", "I", 10)
        pdf.cell(text="This report was generated by DR Analysis System", new_x="LMARGIN", new_y="NEXT", align='C')

        return bytes(pdf.output())

    except Exception as e:
        # Best-effort: surface the failure in the UI and let the caller skip
        # the download button rather than crashing the whole app.
        st.error(f"PDF generation failed: {str(e)}")
        return None
|
|
| |
def main():
    """Streamlit entry point.

    Flow: upload a retinal scan -> classify DR stage -> (only when DR is
    present) segment lesions, show the overlay and metrics, and offer mask
    and PDF-report downloads.

    Fix: previously the segmentation pipeline ran even when the classifier
    predicted "No_DR", contradicting the on-screen message "no segmentation
    needed". It is now skipped for No_DR cases.
    """
    st.set_page_config(layout="wide")
    st.title("Diabetic Retinopathy Analysis")

    uploaded_file = st.file_uploader("Upload retinal scan image",
                                     type=["jpg", "jpeg", "png"],
                                     label_visibility="visible")
    if not uploaded_file:
        st.info("Please upload an image")
        return

    try:
        original_image = Image.open(uploaded_file).convert('RGB')
        col1, col2 = st.columns(2)

        with col1:
            st.image(original_image, caption="Original Image", use_container_width=True)

        # --- Classification -------------------------------------------------
        classifier = load_classifier()
        clf_processed = preprocess_classifier(original_image)
        img_tensor = get_classifier_transform()(Image.fromarray(clf_processed)).unsqueeze(0).to(device)

        with torch.no_grad():
            # Model head is LogSoftmax, so exponentiate to get probabilities.
            logps = classifier(img_tensor)
            ps = torch.exp(logps)
            pred_class = torch.argmax(ps).item()
            probabilities = ps[0].cpu().numpy() * 100

        st.subheader("Classification Results")
        predicted_class_name = CLASS_NAMES[pred_class]
        uk_grade = UK_GRADES[predicted_class_name]

        if predicted_class_name == "No_DR":
            st.success(f"""
            **Prediction:** {predicted_class_name}
            **Grade:** {uk_grade}
            """)
            st.write("No diabetic retinopathy detected - no segmentation needed.")
        else:
            st.error(f"""
            **Prediction:** {predicted_class_name}
            **Grade:** {uk_grade}
            """)

        st.write("**Confidence Levels:**")
        for name, prob in zip(CLASS_NAMES, probabilities):
            st.progress(int(prob))  # probabilities are in 0-100 already
            st.write(f"{name}: {prob:.1f}%")

        if predicted_class_name == "No_DR":
            return  # no DR detected: skip segmentation, overlay, and report

        # --- Segmentation ---------------------------------------------------
        segmenter = load_segmenter()
        with st.spinner("Detecting lesions..."):
            seg_results = segment_image(original_image, segmenter)
            overlay = create_lesion_overlay(original_image, seg_results['mask'])

        with col2:
            st.image(overlay, caption="Lesion Overlay", use_container_width=True)

        st.write("**Lesion Analysis:**")
        cols = st.columns(3)
        cols[0].metric("Bright Lesions", f"{seg_results['bright_area']:.2f}%")
        cols[1].metric("Red Lesions", f"{seg_results['red_area']:.2f}%")
        cols[2].metric("Total Affected",
                       f"{seg_results['bright_area'] + seg_results['red_area']:.2f}%")

        # --- Downloads ------------------------------------------------------
        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                "Download Mask",
                # Scale class ids {0,1,2} by 85 so they are visible grey levels.
                cv2.imencode('.png', seg_results['mask'] * 85)[1].tobytes(),
                "dr_mask.png",
                "image/png"
            )

        with col2:
            pdf_bytes = generate_pdf_report(
                original_image,
                seg_results['mask'],
                overlay,
                predicted_class_name,
                uk_grade,
                seg_results['bright_area'],
                seg_results['red_area']
            )
            if pdf_bytes is not None:
                st.download_button(
                    "Download Full Report",
                    data=pdf_bytes,
                    file_name="dr_diagnosis_report.pdf",
                    mime="application/pdf"
                )
            else:
                st.warning("Failed to generate PDF report")

    except Exception as e:
        # Top-level UI boundary: show the error instead of crashing the app.
        st.error(f"Error processing image: {str(e)}")
|
|
# Standard script entry point: run the Streamlit app.
if __name__ == "__main__":
    main()