# Streamlit demo app: multi-label chest X-ray diagnosis (CheXNet) with a
# Grad-CAM overlay for the top-scoring label.

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image

from model_utils import generate_gradcam, load_model, predict
from preprocessing import preprocess_image

# Kept as the first Streamlit command: older Streamlit releases require
# set_page_config to run before any other st.* call.
st.set_page_config(page_title="X-ray Diagnosis Demo", layout="centered")
st.title("🩻 X-ray Multi-Label Diagnosis App (CheXNet)")


@st.cache_resource(show_spinner="Loading model...")
def _get_model(device_type: str):
    """Return the model from ``load_model``, built once and reused across reruns.

    Streamlit re-executes this script top to bottom on every widget
    interaction, so an uncached ``load_model`` call would rebuild the model
    on each upload. The device is passed as a plain string ("cuda"/"cpu") so
    the cache key is trivially hashable.

    NOTE(review): ``st.cache_resource`` hands the same object to every browser
    session. If ``predict``/``generate_gradcam`` mutate the model (e.g. attach
    hooks or toggle train/eval mode), confirm that is safe when sessions run
    concurrently.
    """
    return load_model(torch.device(device_type))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = _get_model(device.type)
uploaded_file = st.file_uploader("Upload a chest X-ray", type=["jpg", "jpeg", "png"])

# Side length the Grad-CAM overlay is rendered at (the original hard-coded
# 224). Presumably the model's input resolution — TODO confirm against
# preprocess_image.
CAM_SIZE = 224

if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow raises OSError subclasses (e.g. UnidentifiedImageError) for
        # uploads that are not decodable images or are truncated.
        st.error("Could not read the uploaded file as an image.")
        st.stop()
    st.image(image, caption="Uploaded X-ray", use_container_width=True)

    img_tensor = preprocess_image(image)
    # probs maps label -> score (iterated with .items() and formatted as a
    # float further down).
    probs = predict(model, img_tensor, device)

    # Explain the highest-scoring label (first one on ties). Its position in
    # probs is passed to generate_gradcam as the class index, which assumes
    # predict() returns labels in model-output order — TODO confirm in
    # model_utils.
    target_idx, top_disease = max(enumerate(probs), key=lambda pair: probs[pair[1]])

    cam = generate_gradcam(model, img_tensor, target_idx, device)

    base = np.array(image.resize((CAM_SIZE, CAM_SIZE)))
    # Clip before scaling: casting out-of-range floats to uint8 does not
    # saturate (values above 255 can wrap), which would paint hot regions cold.
    cam = np.clip(np.asarray(cam, dtype=np.float32), 0.0, 1.0)
    # cv2.addWeighted requires both inputs to have the same size.
    if cam.shape != base.shape[:2]:
        cam = cv2.resize(cam, (CAM_SIZE, CAM_SIZE))
    heatmap = cv2.applyColorMap((255 * cam).astype(np.uint8), cv2.COLORMAP_JET)
    # applyColorMap emits BGR, while the PIL-derived base image and st.image's
    # default channel order are RGB; without this swap red and blue invert.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(base, 0.6, heatmap, 0.4, 0)

    st.subheader(f"Grad-CAM Visualization: {top_disease}")
    st.image(overlay, use_container_width=True)

    st.subheader("Predictions")
    for disease, prob in probs.items():
        st.write(f"**{disease}**: {prob:.4f}")