Spaces:
Runtime error
Runtime error
| """ | |
| ============================================================================= | |
| PRETI RETINAL DISEASE DETECTION — HUGGING FACE SPACES | |
| app.py — main application file | |
| ============================================================================= | |
| """ | |
| import gradio as gr | |
| import cv2 | |
| import torch | |
| import torch.nn as nn | |
| import timm | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import matplotlib.gridspec as gridspec | |
| from PIL import Image | |
| from torchvision import transforms | |
| import os | |
| import time | |
| from huggingface_hub import hf_hub_download | |
| # ============================================================================= | |
| # CONFIG | |
| # ============================================================================= | |
# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input geometry. The patch-grid constants are derived from the image and
# patch size so the three values cannot drift apart.
IMAGE_SIZE = 224
PATCH_SIZE = 16
N_SIDE = IMAGE_SIZE // PATCH_SIZE   # patches along one image side (14)
NUM_PATCHES = N_SIDE * N_SIDE       # patches per image (196)

# Class order of the 4 model outputs; the three lists below are zipped
# together, so they must stay index-aligned.
LABEL_NAMES = ['DR', 'GLAUCOMA', 'HR', 'RVO']
LABEL_FULL = ['Diabetic Retinopathy', 'Glaucoma',
              'Hypertensive Retinopathy', 'Retinal Vein Occlusion']
COLORS = ['#FF6B6B', '#51CF66', '#74C0FC', '#FFA94D']

# Per-class decision thresholds: a class counts as detected when its
# probability is >= the value here (the figure labels these as Youden
# thresholds).
THRESHOLDS = {'DR': 0.3894, 'GLAUCOMA': 0.5200,
              'HR': 0.8667, 'RVO': 0.3765}

# Severity bands as (low, high, label), matched with low <= prob < high.
# The first band starts exactly at the class threshold: a rounded lower
# bound (e.g. 0.39 for a 0.3894 threshold) leaves a gap in which a detected
# probability matches no band and is reported as 'Severe'.
SEVERITY = {
    'DR': [(THRESHOLDS['DR'], 0.55, 'Mild'),
           (0.55, 0.75, 'Moderate'),
           (0.75, 1.0, 'Severe')],
    'GLAUCOMA': [(THRESHOLDS['GLAUCOMA'], 0.65, 'Mild'),
                 (0.65, 0.82, 'Moderate'),
                 (0.82, 1.0, 'Severe')],
    'HR': [(THRESHOLDS['HR'], 0.92, 'Mild'),
           (0.92, 0.96, 'Moderate'),
           (0.96, 1.0, 'Severe')],
    'RVO': [(THRESHOLDS['RVO'], 0.55, 'Mild'),
            (0.55, 0.75, 'Moderate'),
            (0.75, 1.0, 'Severe')],
}

# Per-class (signs, cause) text shown in the clinical report.
DISEASE_INFO = {
    'DR': ('Microaneurysms, haemorrhages, hard exudates at macula',
           'Caused by diabetes damaging retinal blood vessels'),
    'GLAUCOMA': ('Enlarged optic cup, thinning neuroretinal rim',
                 'Caused by increased eye pressure damaging optic nerve'),
    'HR': ('Vessel narrowing, AV nipping, flame haemorrhages',
           'Caused by high blood pressure damaging retinal vessels'),
    'RVO': ('Dilated tortuous veins, diffuse haemorrhages near disc',
            'Caused by blockage of retinal vein'),
}
| # ============================================================================= | |
| # MODEL — loads from HuggingFace Hub | |
| # ============================================================================= | |
class PRETIClassifier(nn.Module):
    """Frozen ViT-B/16 encoder followed by a small MLP head with 4 logits."""

    def __init__(self):
        super().__init__()
        # num_classes=0: ask timm for the backbone without a classifier layer.
        backbone = timm.create_model(
            'vit_base_patch16_224', pretrained=True, num_classes=0)
        # Freeze every encoder parameter; only the head stays trainable.
        backbone.requires_grad_(False)
        self.encoder = backbone

        width = backbone.embed_dim
        head_layers = [
            nn.LayerNorm(width),
            nn.Linear(width, 256),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 4),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the head's output for the encoder features of ``x``."""
        features = self.encoder(x)
        return self.head(features)
def load_model():
    """Build the classifier and load weights from ``best_model.pth`` if present.

    Only checkpoint entries whose key exists in the model and whose shape
    matches are copied; everything else keeps the value the freshly built
    model already has. The model is returned on ``DEVICE`` in eval mode.

    Returns:
        PRETIClassifier: the ready-to-use model.
    """
    print("[INFO] Loading model...")
    model = PRETIClassifier().to(DEVICE)
    # Load from local file (uploaded to HF Space)
    model_path = 'best_model.pth'
    if os.path.exists(model_path):
        state = torch.load(model_path, map_location=DEVICE)
        # A checkpoint may wrap the weights in an outer dict. Without
        # unwrapping, no key would match below and the model would silently
        # keep its initial weights.
        for wrapper_key in ('model_state_dict', 'state_dict'):
            if isinstance(state, dict) and \
                    isinstance(state.get(wrapper_key), dict):
                state = state[wrapper_key]
                break
        # Load only weights that match by name and shape
        model_dict = model.state_dict()
        pretrained = {
            k: v for k, v in state.items()
            if k in model_dict
            and getattr(v, 'shape', None) == model_dict[k].shape
        }
        model_dict.update(pretrained)
        model.load_state_dict(model_dict)
        if pretrained:
            print(f"[INFO] ✅ Loaded {len(pretrained)} layers")
        else:
            # Make the failure visible instead of reporting a successful load.
            print("[WARN] best_model.pth matched no model weights — "
                  "keeping initial weights")
    else:
        print("[WARN] best_model.pth not found — using random weights")
    model.eval()
    return model


# Single module-level model instance shared by every request.
model = load_model()
| # ============================================================================= | |
| # PREPROCESSING | |
| # ============================================================================= | |
# Inference-time pipeline: resize to IMAGE_SIZE x IMAGE_SIZE, convert to a
# tensor, then normalise each channel with the mean/std listed below.
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
def preprocess(img_pil):
    """Apply CLAHE to the lightness channel, then run ``val_transform``.

    Args:
        img_pil: PIL image; it is converted to RGB first.

    Returns:
        The output of ``val_transform`` for the contrast-enhanced image.
    """
    rgb = np.array(img_pil.convert('RGB'))
    # Work in LAB so that only lightness (channel 0) is equalised.
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB))
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge((equaliser.apply(lightness), chan_a, chan_b))
    enhanced_rgb = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
    return val_transform(Image.fromarray(enhanced_rgb))
def to_display(t):
    """Convert a (C, H, W) tensor into an (H, W, C) uint8 array.

    Values are min-max rescaled over the whole array to 0..255; the small
    epsilon keeps the division defined for a constant image.
    """
    hwc = t.permute(1, 2, 0).numpy()
    lowest = hwc.min()
    highest = hwc.max()
    scaled = (hwc - lowest) / (highest - lowest + 1e-8) * 255
    return scaled.astype(np.uint8)
| # ============================================================================= | |
| # ATTENTION | |
| # ============================================================================= | |
def get_attention(tensor):
    """Return a min-max normalised attention score per image patch.

    A forward hook on the last encoder block's attention module recomputes
    the softmax attention from that module's input. The result is averaged
    over heads, and the row of token 0 (assumed to be the class token —
    confirm for other backbones) is kept with its own column dropped.

    Args:
        tensor: preprocessed image tensor without a batch dimension.

    Returns:
        1-D tensor of scores rescaled to [0, 1]; a uniform vector of length
        ``NUM_PATCHES`` if the hook captured nothing.
    """
    captured = []

    def hook(m, inp, out):
        x = inp[0]
        B, N, C = x.shape
        head_dim = C // m.num_heads
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = m.qkv(x).reshape(
            B, N, 3, m.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, _ = qkv.unbind(0)
        # Scaled dot-product attention weights, shape (B, heads, N, N).
        a = (q @ k.transpose(-2, -1) * head_dim ** -0.5).softmax(dim=-1)
        captured.append(a.detach().cpu())

    handle = model.encoder.blocks[-1].attn.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model.encoder.forward_features(tensor.unsqueeze(0).to(DEVICE))
    finally:
        # Always detach the hook: the model is shared, so a hook left behind
        # by a failed forward pass would fire on every later request.
        handle.remove()

    if not captured:
        return torch.ones(NUM_PATCHES) / NUM_PATCHES
    a = captured[0].mean(dim=1)[0, 0, 1:]
    return (a - a.min()) / (a.max() - a.min() + 1e-8)
def get_severity(name, prob):
    """Map a class probability to a severity label.

    Args:
        name: class key into ``THRESHOLDS`` and ``SEVERITY``.
        prob: predicted probability for that class.

    Returns:
        '' when ``prob`` is below the class threshold; otherwise the label
        of the first band whose upper bound is above ``prob``, or 'Severe'
        when ``prob`` is at or above the last upper bound.
    """
    # `not >=` (rather than `<`) keeps NaN on the "not detected" side.
    if not prob >= THRESHOLDS[name]:
        return ''
    # Only upper bounds are compared: the band lower bounds are rounded and
    # can sit slightly above the threshold, which previously made a
    # just-detected probability fall through every band and come back
    # as 'Severe'.
    for _lo, hi, label in SEVERITY[name]:
        if prob < hi:
            return label
    return 'Severe'
| # ============================================================================= | |
| # FIGURE | |
| # ============================================================================= | |
# Palette shared by the figure helpers below.
_FIG_BG = '#0D1117'
_FIG_CARD = '#161B22'
_FIG_BORDER = '#30363D'
_FIG_MUTED = '#8B949E'


def _caption_panel(ax, title, subtitle=''):
    """Title an image axis, add an optional caption below it, hide the axes."""
    ax.set_title(title, color='white',
                 fontsize=11, fontweight='bold', pad=6)
    if subtitle:
        ax.text(0.5, -0.08, subtitle,
                transform=ax.transAxes,
                ha='center', color=_FIG_MUTED, fontsize=9)
    ax.axis('off')


def _draw_image_row(fig, gs, orig_np, recon_np, attn_map, mask_full,
                    n_selected):
    """Draw row 1: original, attention map, selected patches, reconstruction."""
    ax0 = fig.add_subplot(gs[0, 0])
    ax0.set_facecolor(_FIG_CARD)
    ax0.imshow(orig_np)
    _caption_panel(ax0, '① Original Image', 'CLAHE preprocessed')

    ax1 = fig.add_subplot(gs[0, 1])
    ax1.set_facecolor(_FIG_CARD)
    im = ax1.imshow(attn_map, cmap='inferno',
                    interpolation='bilinear', vmin=0, vmax=1)
    _caption_panel(ax1, '② PRETI RAAM Attention', 'Hot = disease focus')
    cb = plt.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
    cb.ax.tick_params(colors='white')

    ax2 = fig.add_subplot(gs[0, 2])
    ax2.set_facecolor(_FIG_CARD)
    ax2.imshow(orig_np)
    ax2.imshow(mask_full, alpha=0.55, cmap='YlOrRd')
    _caption_panel(ax2, '③ AGPT Selected Patches',
                   f'{n_selected}/196 patches (30%)')

    ax3 = fig.add_subplot(gs[0, 3])
    ax3.set_facecolor(_FIG_CARD)
    ax3.imshow(recon_np)
    _caption_panel(ax3, '④ Doctor Receives', 'Grey = not transmitted')


def _draw_prediction_bars(ax4, probs):
    """Draw row 2: one horizontal probability bar per class with threshold."""
    ax4.set_facecolor(_FIG_CARD)
    y = np.arange(len(LABEL_NAMES))
    # Full-width background track, then the coloured probability bars.
    ax4.barh(y, [1.0] * len(LABEL_NAMES), color='#21262D',
             height=0.52, zorder=0)
    ax4.barh(y, probs, color=COLORS, height=0.52,
             edgecolor=_FIG_BG, linewidth=0.5, zorder=1)
    for i, (name, prob, color, full) in enumerate(
            zip(LABEL_NAMES, probs, COLORS, LABEL_FULL)):
        th = THRESHOLDS[name]
        det = prob >= th
        sev = get_severity(name, prob)
        ax4.plot([th, th], [y[i] - 0.3, y[i] + 0.3],
                 color='white', lw=2, linestyle='--', zorder=3)
        ax4.text(prob + 0.01, y[i], f'{prob:.3f}',
                 va='center', color='white',
                 fontsize=12, fontweight='bold', zorder=4)
        if det:
            ax4.text(0.72, y[i], f' ✓ {sev.upper()}',
                     va='center', color=color,
                     fontsize=11, fontweight='bold',
                     transform=ax4.get_yaxis_transform())
        else:
            ax4.text(0.72, y[i], ' ✗ Not Detected',
                     va='center', color='#484F58', fontsize=11,
                     transform=ax4.get_yaxis_transform())
        ax4.text(-0.01, y[i], full, va='center', ha='right',
                 color='white', fontsize=10,
                 transform=ax4.get_yaxis_transform())
    ax4.set_yticks([])
    ax4.set_xlim(0, 1.0)
    ax4.set_xlabel('Predicted Probability', color=_FIG_MUTED, fontsize=10)
    ax4.set_title(
        'Disease Predictions · Dashed = Youden threshold',
        color='white', fontsize=12, fontweight='bold', pad=8)
    ax4.tick_params(colors=_FIG_MUTED)
    for sp in ['top', 'right', 'left']:
        ax4.spines[sp].set_visible(False)
    ax4.spines['bottom'].set_color(_FIG_BORDER)


def _draw_stats_cards(fig, gs, n_selected):
    """Draw row 3: four cards with the fixed transmission figures."""
    orig_kb = 588.0
    packet_kb = 178.0
    stats = [
        ('PATCHES SENT', f'{n_selected} / 196', '30% of image', '#74C0FC'),
        ('DATA TRANSMITTED', f'{packet_kb:.0f} KB',
         f'was {orig_kb:.0f} KB', '#51CF66'),
        ('BANDWIDTH SAVED', '70.3%',
         f'{orig_kb-packet_kb:.0f} KB reduced', '#FF6B6B'),
        ('TIME SAVED @2G', '32 sec',
         f'{packet_kb/(100/8):.0f}s vs {orig_kb/(100/8):.0f}s', '#FFA94D'),
    ]
    for j, (title, value, sub, color) in enumerate(stats):
        ax = fig.add_subplot(gs[2, j])
        ax.set_facecolor(_FIG_CARD)
        ax.axis('off')
        for sp in ax.spines.values():
            sp.set_visible(True)
            sp.set_edgecolor(color)
            sp.set_linewidth(1.5)
        ax.add_patch(plt.Rectangle(
            (0, 0.82), 1, 0.18,
            transform=ax.transAxes,
            color=color, alpha=0.15, clip_on=False))
        ax.text(0.5, 0.90, title, transform=ax.transAxes,
                ha='center', color=color,
                fontsize=9, fontweight='bold', va='center')
        ax.text(0.5, 0.52, value, transform=ax.transAxes,
                ha='center', color='white',
                fontsize=24, fontweight='bold', va='center')
        ax.text(0.5, 0.20, sub, transform=ax.transAxes,
                ha='center', color=_FIG_MUTED,
                fontsize=10, va='center')


def build_figure(tensor, probs, attn, top_idx, mask_full, recon):
    """Render the three-row result dashboard and return the PNG path.

    Args:
        tensor: preprocessed image tensor, shown via ``to_display``.
        probs: one probability per entry of ``LABEL_NAMES``.
        attn: tensor of per-patch scores, reshaped to (N_SIDE, N_SIDE).
        top_idx: selected patch indices; only their count is displayed.
        mask_full: 2-D array overlaid on the image to mark selected patches.
        recon: reconstructed image tensor, shown via ``to_display``.

    Returns:
        str: path of a newly created PNG file. Each call writes its own
        file, so concurrent requests cannot overwrite each other's result.
        The file is not deleted here.
    """
    # Local import so this block does not depend on the file's import list.
    import tempfile

    fig = plt.figure(figsize=(22, 15), facecolor=_FIG_BG)
    try:
        gs = gridspec.GridSpec(3, 4,
                               height_ratios=[1.1, 0.9, 0.85],
                               hspace=0.5, wspace=0.35,
                               left=0.04, right=0.97,
                               top=0.93, bottom=0.04)
        fig.text(0.5, 0.965,
                 'PRETI Retinal Disease Detection System',
                 ha='center', color='white',
                 fontsize=20, fontweight='bold')
        fig.text(0.5, 0.945,
                 'PRETI Foundation Model + AGPT Bandwidth-Efficient Telemedicine',
                 ha='center', color=_FIG_MUTED, fontsize=12)

        orig_np = to_display(tensor)
        recon_np = to_display(recon)
        attn_map = attn.reshape(N_SIDE, N_SIDE).numpy()

        _draw_image_row(fig, gs, orig_np, recon_np, attn_map, mask_full,
                        len(top_idx))
        _draw_prediction_bars(fig.add_subplot(gs[1, :]), probs)
        _draw_stats_cards(fig, gs, len(top_idx))

        # Reserve a unique file name; a fixed path would be shared by all
        # requests and could be overwritten before it is served.
        with tempfile.NamedTemporaryFile(
                prefix='preti_result_', suffix='.png', delete=False) as tmp:
            out_path = tmp.name
        fig.savefig(out_path, dpi=120,
                    bbox_inches='tight', facecolor=_FIG_BG)
    finally:
        # Close even when drawing fails, so pyplot does not keep the figure.
        plt.close(fig)
    return out_path
| # ============================================================================= | |
| # INFERENCE | |
| # ============================================================================= | |
def analyze(image):
    """Run the full pipeline on one image and build the figure and report.

    Args:
        image: PIL image or numpy array from the upload widget, or ``None``.

    Returns:
        tuple: ``(figure_path, report_text)``; ``(None, message)`` when no
        image was supplied.
    """
    if image is None:
        return None, "⚠️ Please upload a retinal image."
    t0 = time.time()
    img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    tensor = preprocess(img)
    with torch.no_grad():
        probs = torch.sigmoid(
            model(tensor.unsqueeze(0).to(DEVICE))
        ).cpu().float().numpy()[0]

    attn = get_attention(tensor)
    # Keep the highest-scoring patches; clamp so topk cannot be asked for
    # more entries than the attention vector holds.
    top_k = min(58, attn.numel())
    _, top = torch.topk(attn, top_k)
    top = top.sort().values
    # Plain ints once, instead of indexing numpy with 0-d tensors per patch.
    patch_ids = top.tolist()

    mask = np.zeros((N_SIDE, N_SIDE))
    recon = torch.ones(3, IMAGE_SIZE, IMAGE_SIZE) * 0.5  # mid-grey canvas
    P = PATCH_SIZE
    for idx in patch_ids:
        r, c = divmod(idx, N_SIDE)
        mask[r, c] = 1
        recon[:, r*P:(r+1)*P, c*P:(c+1)*P] = \
            tensor[:, r*P:(r+1)*P, c*P:(c+1)*P]
    # Blow the patch grid up to pixel resolution for the overlay.
    mask_full = np.kron(mask, np.ones((P, P)))

    fig_path = build_figure(tensor, probs, attn, top, mask_full, recon)
    elapsed = time.time() - t0

    detected = [n for n, p in zip(LABEL_NAMES, probs)
                if p >= THRESHOLDS[n]]
    status = ', '.join(detected) if detected else 'NORMAL'

    parts = [
        f"{'═'*46}\n PRETI RETINAL ANALYSIS REPORT\n{'═'*46}\n\n",
        f" STATUS : {'⚠️ ' + status if detected else '✅ ' + status}\n",
        f" TIME : {elapsed:.2f} seconds\n\n",
        f" DISEASE PROBABILITIES\n {'─'*40}\n",
    ]
    for name, prob in zip(LABEL_NAMES, probs):
        det = prob >= THRESHOLDS[name]
        sev = get_severity(name, prob) if det else ''
        filled = int(prob*20)
        bar = '█'*filled + '░'*(20-filled)
        flg = f'✓ {sev}' if det else '✗'
        parts.append(f" {name:<10} [{bar}] {prob:.3f} {flg}\n")

    parts.append(f"\n CLINICAL INDICATORS\n {'─'*40}\n")
    if detected:
        for name in detected:
            signs, cause = DISEASE_INFO[name]
            parts.append(
                f" {name}:\n Signs : {signs}\n Cause : {cause}\n\n")
    else:
        parts.append(" No pathological features detected.\n\n")

    # Patch count and share come from the actual selection so the report
    # cannot drift from the computation above.
    share = 100.0 * len(patch_ids) / NUM_PATCHES
    parts.append(f" AGPT TRANSMISSION\n {'─'*40}\n")
    parts.append(
        f" Patches : {len(patch_ids)}/{NUM_PATCHES} ({share:.0f}%)\n")
    parts.append(" Original : 588 KB (~47s @2G)\n")
    parts.append(" Packet : 178 KB (~15s @2G)\n")
    parts.append(" Saved : 70.3% bandwidth\n\n")
    parts.append(
        f"{'═'*46}\n PRETI Telemedicine · BME Thesis 2025\n{'═'*46}\n")
    return fig_path, ''.join(parts)
| # ============================================================================= | |
| # CSS | |
| # ============================================================================= | |
# Stylesheet passed to gr.Blocks(css=...): dark page background, green
# gradient primary button, card-styled boxes/panels and a monospace textarea.
# NOTE(review): the .gr-* selectors depend on the class names emitted by the
# installed Gradio version — confirm they still match.
CSS = """
body, .gradio-container {
background: #0D1117 !important;
color: #E6EDF3 !important;
font-family: 'Segoe UI', system-ui, sans-serif !important;
}
.gr-button-primary {
background: linear-gradient(135deg, #238636, #2EA043) !important;
border: 1px solid #2EA043 !important;
color: white !important;
font-size: 16px !important;
font-weight: bold !important;
padding: 12px 32px !important;
border-radius: 8px !important;
}
.gr-button-primary:hover {
background: linear-gradient(135deg, #2EA043, #3FB950) !important;
}
.gr-box, .gr-panel {
background: #161B22 !important;
border: 1px solid #30363D !important;
border-radius: 10px !important;
}
textarea {
background: #0D1117 !important;
color: #E6EDF3 !important;
border: 1px solid #30363D !important;
font-family: 'Courier New', monospace !important;
font-size: 13px !important;
}
"""
| # ============================================================================= | |
| # GRADIO APP | |
| # ============================================================================= | |
# UI layout: header banner, a row with the upload column and the result
# image, a row with the text report, a citation footer, and the click wiring.
# NOTE(review): the nesting of the rows/columns below was reconstructed from
# a source whose indentation had been flattened — confirm against the
# deployed app, in particular whether the report row sits at top level.
with gr.Blocks(css=CSS, title="PRETI Retinal AI") as demo:
    # Header banner with title, subtitle and metric badges.
    gr.HTML("""
<div style="text-align:center;padding:24px 0 12px">
<div style="font-size:13px;color:#58A6FF;font-weight:600;
letter-spacing:2px;margin-bottom:6px">
UNDERGRADUATE THESIS · BIOMEDICAL ENGINEERING · 2025
</div>
<div style="font-size:28px;font-weight:700;color:#E6EDF3;
margin-bottom:8px">
🔬 PRETI Retinal Disease Detection
</div>
<div style="font-size:15px;color:#8B949E;margin-bottom:16px">
Multi-label detection of
<span style="color:#FF6B6B">DR</span> ·
<span style="color:#51CF66">Glaucoma</span> ·
<span style="color:#74C0FC">HR</span> ·
<span style="color:#FFA94D">RVO</span>
with AGPT Bandwidth-Efficient Transmission
</div>
<div style="display:flex;justify-content:center;
gap:8px;flex-wrap:wrap;margin-bottom:8px">
<span style="background:#21262D;color:#51CF66;padding:4px 12px;
border-radius:20px;font-size:12px;font-weight:600;
border:1px solid #238636">✓ Macro AUC 0.9903</span>
<span style="background:#21262D;color:#74C0FC;padding:4px 12px;
border-radius:20px;font-size:12px;font-weight:600;
border:1px solid #1F6FEB">✓ 70.3% Bandwidth Saved</span>
<span style="background:#21262D;color:#FFA94D;padding:4px 12px;
border-radius:20px;font-size:12px;font-weight:600;
border:1px solid #9E6A03">✓ 4 Diseases Simultaneously</span>
<span style="background:#21262D;color:#FF6B6B;padding:4px 12px;
border-radius:20px;font-size:12px;font-weight:600;
border:1px solid #8B1A1A">✓ PRETI Foundation Model</span>
</div>
</div>
""")
    with gr.Row():
        # Left column: upload instructions, image input, button, pipeline card.
        with gr.Column(scale=1, min_width=280):
            gr.HTML("""
<div style="background:#161B22;border:1px solid #30363D;
border-radius:10px;padding:16px;margin-bottom:8px">
<div style="color:#58A6FF;font-size:11px;font-weight:600;
letter-spacing:1px;margin-bottom:10px">
UPLOAD RETINAL IMAGE
</div>
<div style="color:#8B949E;font-size:12px;line-height:1.8">
• Any fundus photograph<br>
• JPEG or PNG format<br>
• Any resolution supported<br>
• Auto CLAHE preprocessing
</div>
</div>""")
            # type="pil": analyze() receives a PIL image (or None when empty).
            inp = gr.Image(label="Retinal Fundus Image",
                           type="pil", height=260)
            btn = gr.Button("🔍 Analyze Retina",
                            variant="primary", size="lg")
            gr.HTML("""
<div style="background:#161B22;border:1px solid #30363D;
border-radius:10px;padding:14px;margin-top:10px">
<div style="color:#58A6FF;font-size:11px;font-weight:600;
letter-spacing:1px;margin-bottom:8px">
PIPELINE
</div>
<div style="color:#8B949E;font-size:11px;line-height:1.8">
① CLAHE contrast enhancement<br>
② PRETI ViT-B/16 encoding<br>
③ 4-disease classification<br>
④ RAAM attention extraction<br>
⑤ AGPT top-30% patch select<br>
⑥ 70.3% bandwidth reduction
</div>
</div>""")
        # Right column: the rendered dashboard image returned by analyze().
        with gr.Column(scale=3):
            out_img = gr.Image(label="Analysis Result", height=520)
    # Text report returned by analyze().
    with gr.Row():
        out_txt = gr.Textbox(label="Clinical Report",
                             lines=22, max_lines=28,
                             show_copy_button=True)
    # Citation footer.
    gr.HTML("""
<div style="text-align:center;padding:12px 0;
color:#484F58;font-size:11px">
PRETI: Lee et al., arXiv:2505.12233 (2025) ·
Focal Loss: Lin et al., TPAMI 2020 ·
Class-Balanced Sampler: Cui et al., CVPR 2019 ·
AGPT: Novel Contribution
</div>""")
    # analyze() returns (figure_path, report_text), matching the two outputs.
    btn.click(fn=analyze,
              inputs=[inp],
              outputs=[out_img, out_txt])
# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()