# DACNet / app.py
# cfgpp's picture
# Update app.py
# 26b2386 verified
# app.py
import streamlit as st
from PIL import Image
import torch
from model_utils import load_model, predict, generate_gradcam
from preprocessing import preprocess_image
import numpy as np
import cv2
# Page chrome. Streamlit requires set_page_config to be the first Streamlit
# command executed in the script.
st.set_page_config(page_title="X-ray Diagnosis Demo", layout="centered")
st.title("🩻 X-ray Multi-Label Diagnosis App (CheXNet)")

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_cached_model(device_type: str):
    """Return the model produced by ``load_model`` for the given device type.

    Streamlit re-executes this whole script on every widget interaction
    (including each file upload). Caching the loaded model as a shared
    resource avoids calling ``load_model`` again on every rerun.

    Args:
        device_type: Device type string ("cuda" or "cpu"). A plain string is
            used as the cache key so Streamlit does not need to hash a
            ``torch.device`` object.

    Returns:
        Whatever ``load_model`` returns for ``torch.device(device_type)``.
    """
    return load_model(torch.device(device_type))


model = _load_cached_model(device.type)
# Main interaction: upload an X-ray, run the classifier, then show a Grad-CAM
# overlay for the highest-scoring label followed by every label's score.
uploaded_file = st.file_uploader("Upload a chest X-ray", type=["jpg", "jpeg", "png"])

# file_uploader returns None until the user has provided a file.
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded X-ray", use_container_width=True)

    img_tensor = preprocess_image(image)
    # probs is used as a mapping of label -> score (see .get / .items below).
    probs = predict(model, img_tensor, device)

    # Get top class.
    top_disease = max(probs, key=probs.get)
    # NOTE(review): this assumes the mapping's key order matches the model's
    # output index order — confirm against predict() in model_utils.
    target_idx = list(probs).index(top_disease)

    # Grad-CAM
    cam = generate_gradcam(model, img_tensor, target_idx, device)

    # Overlay on image
    image_resized = image.resize((224, 224))
    img_np = np.array(image_resized)  # RGB, uint8

    # NOTE(review): assumes cam is a 2-D array scaled to [0, 1] — confirm
    # against generate_gradcam() in model_utils.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # applyColorMap emits BGR, while the PIL-derived image (and st.image) are
    # RGB. Convert first, otherwise hot regions are rendered blue, not red.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # addWeighted needs both inputs to be the same size.
    if heatmap.shape[:2] != img_np.shape[:2]:
        heatmap = cv2.resize(heatmap, (img_np.shape[1], img_np.shape[0]))
    overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)

    # Show it
    st.subheader(f"Grad-CAM Visualization: {top_disease}")
    st.image(overlay, use_container_width=True)

    st.subheader("Predictions")
    for disease, prob in probs.items():
        st.write(f"**{disease}**: {prob:.4f}")