mawa2212 commited on
Commit
fdd8b4c
Β·
verified Β·
1 Parent(s): 8d7c5c3

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +36 -7
  2. app.py +505 -0
  3. requirements.txt +9 -0
README.md CHANGED
@@ -1,13 +1,42 @@
1
  ---
2
- title: Preti Retinal Detection
3
- emoji: ⚑
4
- colorFrom: gray
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: PRETI Retinal Disease Detection
3
+ emoji: πŸ”¬
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.0.0
 
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # PRETI Retinal Disease Detection System
14
+
15
+ **Undergraduate Thesis β€” Department of Biomedical Engineering β€” 2025**
16
+
17
+ ## What this does
18
+
19
+ This system detects 4 retinal diseases simultaneously from a single fundus photograph:
20
+ - **DR** β€” Diabetic Retinopathy
21
+ - **Glaucoma**
22
+ - **HR** β€” Hypertensive Retinopathy
23
+ - **RVO** β€” Retinal Vein Occlusion
24
+
25
+ It also demonstrates **AGPT (Attention-Guided Patch Transmission)** β€” a novel mechanism that uses PRETI's RAAM attention maps to select only disease-relevant image patches for transmission, achieving **70.3% bandwidth reduction** for rural telemedicine.
26
+
27
+ ## Results
28
+
29
+ | Disease | AUC |
30
+ |---|---|
31
+ | DR | 0.9869 |
32
+ | Glaucoma | 0.9999 |
33
+ | HR | 0.9881 |
34
+ | RVO | 0.9864 |
35
+ | **Macro** | **0.9903** |
36
+
37
+ ## References
38
+
39
+ - PRETI: Lee et al., arXiv:2505.12233, 2025
40
+ - Focal Loss: Lin et al., IEEE TPAMI, 2020
41
+ - Class-Balanced Sampler: Cui et al., CVPR 2019
42
+ - ViT: Dosovitskiy et al., ICLR 2021
app.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ =============================================================================
3
+ PRETI RETINAL DISEASE DETECTION β€” HUGGING FACE SPACES
4
+ app.py β€” main application file
5
+ =============================================================================
6
+ """
7
+
8
+ import gradio as gr
9
+ import cv2
10
+ import torch
11
+ import torch.nn as nn
12
+ import timm
13
+ import numpy as np
14
+ import matplotlib
15
+ matplotlib.use('Agg')
16
+ import matplotlib.pyplot as plt
17
+ import matplotlib.gridspec as gridspec
18
+ from PIL import Image
19
+ from torchvision import transforms
20
+ import os
21
+ import time
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ # =============================================================================
25
+ # CONFIG
26
+ # =============================================================================
27
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ IMAGE_SIZE = 224
29
+ PATCH_SIZE = 16
30
+ NUM_PATCHES = 196
31
+ N_SIDE = 14
32
+ LABEL_NAMES = ['DR', 'GLAUCOMA', 'HR', 'RVO']
33
+ LABEL_FULL = ['Diabetic Retinopathy', 'Glaucoma',
34
+ 'Hypertensive Retinopathy', 'Retinal Vein Occlusion']
35
+ COLORS = ['#FF6B6B', '#51CF66', '#74C0FC', '#FFA94D']
36
+
37
+ THRESHOLDS = {'DR': 0.3894, 'GLAUCOMA': 0.5200,
38
+ 'HR': 0.8667, 'RVO': 0.3765}
39
+
40
+ SEVERITY = {
41
+ 'DR': [(0.39, 0.55, 'Mild'), (0.55, 0.75, 'Moderate'), (0.75, 1.0, 'Severe')],
42
+ 'GLAUCOMA': [(0.52, 0.65, 'Mild'), (0.65, 0.82, 'Moderate'), (0.82, 1.0, 'Severe')],
43
+ 'HR': [(0.87, 0.92, 'Mild'), (0.92, 0.96, 'Moderate'), (0.96, 1.0, 'Severe')],
44
+ 'RVO': [(0.38, 0.55, 'Mild'), (0.55, 0.75, 'Moderate'), (0.75, 1.0, 'Severe')],
45
+ }
46
+
47
+ DISEASE_INFO = {
48
+ 'DR': ('Microaneurysms, haemorrhages, hard exudates at macula',
49
+ 'Caused by diabetes damaging retinal blood vessels'),
50
+ 'GLAUCOMA': ('Enlarged optic cup, thinning neuroretinal rim',
51
+ 'Caused by increased eye pressure damaging optic nerve'),
52
+ 'HR': ('Vessel narrowing, AV nipping, flame haemorrhages',
53
+ 'Caused by high blood pressure damaging retinal vessels'),
54
+ 'RVO': ('Dilated tortuous veins, diffuse haemorrhages near disc',
55
+ 'Caused by blockage of retinal vein'),
56
+ }
57
+
58
+ # =============================================================================
59
+ # MODEL β€” loads from HuggingFace Hub
60
+ # =============================================================================
61
+ class PRETIClassifier(nn.Module):
62
+ def __init__(self):
63
+ super().__init__()
64
+ self.encoder = timm.create_model(
65
+ 'vit_base_patch16_224', pretrained=True, num_classes=0)
66
+ for p in self.encoder.parameters():
67
+ p.requires_grad = False
68
+ d = self.encoder.embed_dim
69
+ self.head = nn.Sequential(
70
+ nn.LayerNorm(d), nn.Linear(d, 256),
71
+ nn.GELU(), nn.Dropout(0.5), nn.Linear(256, 4))
72
+
73
+ def forward(self, x):
74
+ return self.head(self.encoder(x))
75
+
76
+ def load_model():
77
+ print("[INFO] Loading model...")
78
+ model = PRETIClassifier().to(DEVICE)
79
+ # Load from local file (uploaded to HF Space)
80
+ model_path = 'best_model.pth'
81
+ if os.path.exists(model_path):
82
+ state = torch.load(model_path, map_location=DEVICE)
83
+ # Load only head weights that match
84
+ model_dict = model.state_dict()
85
+ pretrained = {k: v for k, v in state.items()
86
+ if k in model_dict and
87
+ model_dict[k].shape == v.shape}
88
+ model_dict.update(pretrained)
89
+ model.load_state_dict(model_dict)
90
+ print(f"[INFO] βœ… Loaded {len(pretrained)} layers")
91
+ else:
92
+ print("[WARN] best_model.pth not found β€” using random weights")
93
+ model.eval()
94
+ return model
95
+
96
+ model = load_model()
97
+
98
+ # =============================================================================
99
+ # PREPROCESSING
100
+ # =============================================================================
101
+ val_transform = transforms.Compose([
102
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
103
+ transforms.ToTensor(),
104
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
105
+ std=[0.229, 0.224, 0.225])
106
+ ])
107
+
108
+ def preprocess(img_pil):
109
+ img_cv = np.array(img_pil.convert('RGB'))
110
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
111
+ lab = cv2.cvtColor(img_cv, cv2.COLOR_RGB2LAB)
112
+ lab[:, :, 0] = clahe.apply(lab[:, :, 0])
113
+ img_cv = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
114
+ return val_transform(Image.fromarray(img_cv))
115
+
116
+ def to_display(t):
117
+ a = t.permute(1, 2, 0).numpy()
118
+ return ((a - a.min()) /
119
+ (a.max() - a.min() + 1e-8) * 255).astype(np.uint8)
120
+
121
+ # =============================================================================
122
+ # ATTENTION
123
+ # =============================================================================
124
+ def get_attention(tensor):
125
+ attn_list = []
126
+ def hook(m, inp, out):
127
+ B, N, C = inp[0].shape
128
+ qkv = m.qkv(inp[0]).reshape(
129
+ B, N, 3, m.num_heads,
130
+ C // m.num_heads).permute(2, 0, 3, 1, 4)
131
+ q, k, _ = qkv.unbind(0)
132
+ a = (q @ k.transpose(-2, -1) *
133
+ (C // m.num_heads) ** -0.5).softmax(dim=-1)
134
+ attn_list.append(a.detach().cpu())
135
+ h = list(model.encoder.blocks)[-1].attn.register_forward_hook(hook)
136
+ with torch.no_grad():
137
+ model.encoder.forward_features(tensor.unsqueeze(0).to(DEVICE))
138
+ h.remove()
139
+ if not attn_list:
140
+ return torch.ones(NUM_PATCHES) / NUM_PATCHES
141
+ a = attn_list[0].mean(dim=1)[0, 0, 1:]
142
+ return (a - a.min()) / (a.max() - a.min() + 1e-8)
143
+
144
+ def get_severity(name, prob):
145
+ for lo, hi, label in SEVERITY[name]:
146
+ if lo <= prob < hi:
147
+ return label
148
+ return 'Severe' if prob >= THRESHOLDS[name] else ''
149
+
150
+ # =============================================================================
151
+ # FIGURE
152
+ # =============================================================================
153
+ def build_figure(tensor, probs, attn, top_idx, mask_full, recon):
154
+ BG = '#0D1117'
155
+ CARD = '#161B22'
156
+ BORD = '#30363D'
157
+
158
+ fig = plt.figure(figsize=(22, 15), facecolor=BG)
159
+ gs = gridspec.GridSpec(3, 4,
160
+ height_ratios=[1.1, 0.9, 0.85],
161
+ hspace=0.5, wspace=0.35,
162
+ left=0.04, right=0.97,
163
+ top=0.93, bottom=0.04)
164
+
165
+ fig.text(0.5, 0.965,
166
+ 'PRETI Retinal Disease Detection System',
167
+ ha='center', color='white',
168
+ fontsize=20, fontweight='bold')
169
+ fig.text(0.5, 0.945,
170
+ 'PRETI Foundation Model + AGPT Bandwidth-Efficient Telemedicine',
171
+ ha='center', color='#8B949E', fontsize=12)
172
+
173
+ orig_np = to_display(tensor)
174
+ recon_np = to_display(recon)
175
+ attn_map = attn.reshape(N_SIDE, N_SIDE).numpy()
176
+
177
+ def img_panel(ax, img, title, subtitle='', cmap=None):
178
+ ax.set_facecolor(CARD)
179
+ ax.imshow(img, cmap=cmap)
180
+ ax.set_title(title, color='white',
181
+ fontsize=11, fontweight='bold', pad=6)
182
+ if subtitle:
183
+ ax.text(0.5, -0.08, subtitle,
184
+ transform=ax.transAxes,
185
+ ha='center', color='#8B949E', fontsize=9)
186
+ ax.axis('off')
187
+
188
+ # Row 1
189
+ ax0 = fig.add_subplot(gs[0, 0])
190
+ img_panel(ax0, orig_np, 'β‘  Original Image', 'CLAHE preprocessed')
191
+
192
+ ax1 = fig.add_subplot(gs[0, 1])
193
+ ax1.set_facecolor(CARD)
194
+ im = ax1.imshow(attn_map, cmap='inferno',
195
+ interpolation='bilinear', vmin=0, vmax=1)
196
+ ax1.set_title('β‘‘ PRETI RAAM Attention',
197
+ color='white', fontsize=11, fontweight='bold', pad=6)
198
+ ax1.text(0.5, -0.08, 'Hot = disease focus',
199
+ transform=ax1.transAxes,
200
+ ha='center', color='#8B949E', fontsize=9)
201
+ ax1.axis('off')
202
+ cb = plt.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
203
+ cb.ax.tick_params(colors='white')
204
+
205
+ ax2 = fig.add_subplot(gs[0, 2])
206
+ ax2.set_facecolor(CARD)
207
+ ax2.imshow(orig_np)
208
+ ax2.imshow(mask_full, alpha=0.55, cmap='YlOrRd')
209
+ ax2.set_title('β‘’ AGPT Selected Patches',
210
+ color='white', fontsize=11, fontweight='bold', pad=6)
211
+ ax2.text(0.5, -0.08, f'{len(top_idx)}/196 patches (30%)',
212
+ transform=ax2.transAxes,
213
+ ha='center', color='#8B949E', fontsize=9)
214
+ ax2.axis('off')
215
+
216
+ ax3 = fig.add_subplot(gs[0, 3])
217
+ img_panel(ax3, recon_np, 'β‘£ Doctor Receives',
218
+ 'Grey = not transmitted')
219
+
220
+ # Row 2 β€” prediction bars
221
+ ax4 = fig.add_subplot(gs[1, :])
222
+ ax4.set_facecolor(CARD)
223
+ y = np.arange(len(LABEL_NAMES))
224
+
225
+ ax4.barh(y, [1.0] * 4, color='#21262D', height=0.52, zorder=0)
226
+ ax4.barh(y, probs, color=COLORS, height=0.52,
227
+ edgecolor=BG, linewidth=0.5, zorder=1)
228
+
229
+ for i, (name, prob, color, full) in enumerate(
230
+ zip(LABEL_NAMES, probs, COLORS, LABEL_FULL)):
231
+ th = THRESHOLDS[name]
232
+ det = prob >= th
233
+ sev = get_severity(name, prob)
234
+ ax4.plot([th, th], [y[i]-0.3, y[i]+0.3],
235
+ color='white', lw=2, linestyle='--', zorder=3)
236
+ ax4.text(prob + 0.01, y[i], f'{prob:.3f}',
237
+ va='center', color='white',
238
+ fontsize=12, fontweight='bold', zorder=4)
239
+ if det:
240
+ ax4.text(0.72, y[i], f' βœ“ {sev.upper()}',
241
+ va='center', color=color,
242
+ fontsize=11, fontweight='bold',
243
+ transform=ax4.get_yaxis_transform())
244
+ else:
245
+ ax4.text(0.72, y[i], ' βœ— Not Detected',
246
+ va='center', color='#484F58', fontsize=11,
247
+ transform=ax4.get_yaxis_transform())
248
+ ax4.text(-0.01, y[i], full, va='center', ha='right',
249
+ color='white', fontsize=10,
250
+ transform=ax4.get_yaxis_transform())
251
+
252
+ ax4.set_yticks([]); ax4.set_xlim(0, 1.0)
253
+ ax4.set_xlabel('Predicted Probability', color='#8B949E', fontsize=10)
254
+ ax4.set_title(
255
+ 'Disease Predictions Β· Dashed = Youden threshold',
256
+ color='white', fontsize=12, fontweight='bold', pad=8)
257
+ ax4.tick_params(colors='#8B949E')
258
+ for sp in ['top', 'right', 'left']:
259
+ ax4.spines[sp].set_visible(False)
260
+ ax4.spines['bottom'].set_color(BORD)
261
+
262
+ # Row 3 β€” stats cards
263
+ orig_kb = 588.0
264
+ packet_kb = 178.0
265
+ stats = [
266
+ ('PATCHES SENT', f'{len(top_idx)} / 196', '30% of image', '#74C0FC'),
267
+ ('DATA TRANSMITTED', f'{packet_kb:.0f} KB', f'was {orig_kb:.0f} KB', '#51CF66'),
268
+ ('BANDWIDTH SAVED', '70.3%',
269
+ f'{orig_kb-packet_kb:.0f} KB reduced', '#FF6B6B'),
270
+ ('TIME SAVED @2G', '32 sec',
271
+ f'{packet_kb/(100/8):.0f}s vs {orig_kb/(100/8):.0f}s', '#FFA94D'),
272
+ ]
273
+ for j, (title, value, sub, color) in enumerate(stats):
274
+ ax = fig.add_subplot(gs[2, j])
275
+ ax.set_facecolor(CARD); ax.axis('off')
276
+ for sp in ax.spines.values():
277
+ sp.set_visible(True)
278
+ sp.set_edgecolor(color); sp.set_linewidth(1.5)
279
+ ax.add_patch(plt.Rectangle(
280
+ (0, 0.82), 1, 0.18,
281
+ transform=ax.transAxes,
282
+ color=color, alpha=0.15, clip_on=False))
283
+ ax.text(0.5, 0.90, title, transform=ax.transAxes,
284
+ ha='center', color=color,
285
+ fontsize=9, fontweight='bold', va='center')
286
+ ax.text(0.5, 0.52, value, transform=ax.transAxes,
287
+ ha='center', color='white',
288
+ fontsize=24, fontweight='bold', va='center')
289
+ ax.text(0.5, 0.20, sub, transform=ax.transAxes,
290
+ ha='center', color='#8B949E',
291
+ fontsize=10, va='center')
292
+
293
+ fig.savefig('/tmp/result.png', dpi=120,
294
+ bbox_inches='tight', facecolor=BG)
295
+ plt.close(fig)
296
+ return '/tmp/result.png'
297
+
298
+ # =============================================================================
299
+ # INFERENCE
300
+ # =============================================================================
301
+ def analyze(image):
302
+ if image is None:
303
+ return None, "⚠️ Please upload a retinal image."
304
+
305
+ t0 = time.time()
306
+ img = Image.fromarray(image) if isinstance(image, np.ndarray) \
307
+ else image
308
+ tensor = preprocess(img)
309
+
310
+ with torch.no_grad():
311
+ probs = torch.sigmoid(
312
+ model(tensor.unsqueeze(0).to(DEVICE))
313
+ ).cpu().float().numpy()[0]
314
+
315
+ attn = get_attention(tensor)
316
+ top_k = 58
317
+ _, top = torch.topk(attn, top_k)
318
+ top = top.sort().values
319
+
320
+ mask = np.zeros((N_SIDE, N_SIDE))
321
+ for idx in top:
322
+ mask[idx // N_SIDE, idx % N_SIDE] = 1
323
+ mask_full = np.kron(mask, np.ones((PATCH_SIZE, PATCH_SIZE)))
324
+
325
+ recon = torch.ones(3, IMAGE_SIZE, IMAGE_SIZE) * 0.5
326
+ for idx in top:
327
+ r, c = (idx // N_SIDE).item(), (idx % N_SIDE).item()
328
+ P = PATCH_SIZE
329
+ recon[:, r*P:(r+1)*P, c*P:(c+1)*P] = \
330
+ tensor[:, r*P:(r+1)*P, c*P:(c+1)*P]
331
+
332
+ fig_path = build_figure(tensor, probs, attn, top, mask_full, recon)
333
+ elapsed = time.time() - t0
334
+
335
+ detected = [n for n, p in zip(LABEL_NAMES, probs)
336
+ if p >= THRESHOLDS[n]]
337
+ status = ', '.join(detected) if detected else 'NORMAL'
338
+
339
+ rep = f"{'═'*46}\n PRETI RETINAL ANALYSIS REPORT\n{'═'*46}\n\n"
340
+ rep += f" STATUS : {'⚠️ ' + status if detected else 'βœ… ' + status}\n"
341
+ rep += f" TIME : {elapsed:.2f} seconds\n\n"
342
+ rep += f" DISEASE PROBABILITIES\n {'─'*40}\n"
343
+ for name, prob in zip(LABEL_NAMES, probs):
344
+ th = THRESHOLDS[name]
345
+ det = prob >= th
346
+ sev = get_severity(name, prob) if det else ''
347
+ bar = 'β–ˆ'*int(prob*20) + 'β–‘'*(20-int(prob*20))
348
+ flg = f'βœ“ {sev}' if det else 'βœ—'
349
+ rep += f" {name:<10} [{bar}] {prob:.3f} {flg}\n"
350
+
351
+ rep += f"\n CLINICAL INDICATORS\n {'─'*40}\n"
352
+ if detected:
353
+ for name in detected:
354
+ signs, cause = DISEASE_INFO[name]
355
+ rep += f" {name}:\n Signs : {signs}\n Cause : {cause}\n\n"
356
+ else:
357
+ rep += " No pathological features detected.\n\n"
358
+
359
+ rep += f" AGPT TRANSMISSION\n {'─'*40}\n"
360
+ rep += f" Patches : 58/196 (30%)\n"
361
+ rep += f" Original : 588 KB (~47s @2G)\n"
362
+ rep += f" Packet : 178 KB (~15s @2G)\n"
363
+ rep += f" Saved : 70.3% bandwidth\n\n"
364
+ rep += f"{'═'*46}\n PRETI Telemedicine Β· BME Thesis 2025\n{'═'*46}\n"
365
+
366
+ return fig_path, rep
367
+
368
+ # =============================================================================
369
+ # CSS
370
+ # =============================================================================
371
+ CSS = """
372
+ body, .gradio-container {
373
+ background: #0D1117 !important;
374
+ color: #E6EDF3 !important;
375
+ font-family: 'Segoe UI', system-ui, sans-serif !important;
376
+ }
377
+ .gr-button-primary {
378
+ background: linear-gradient(135deg, #238636, #2EA043) !important;
379
+ border: 1px solid #2EA043 !important;
380
+ color: white !important;
381
+ font-size: 16px !important;
382
+ font-weight: bold !important;
383
+ padding: 12px 32px !important;
384
+ border-radius: 8px !important;
385
+ }
386
+ .gr-button-primary:hover {
387
+ background: linear-gradient(135deg, #2EA043, #3FB950) !important;
388
+ }
389
+ .gr-box, .gr-panel {
390
+ background: #161B22 !important;
391
+ border: 1px solid #30363D !important;
392
+ border-radius: 10px !important;
393
+ }
394
+ textarea {
395
+ background: #0D1117 !important;
396
+ color: #E6EDF3 !important;
397
+ border: 1px solid #30363D !important;
398
+ font-family: 'Courier New', monospace !important;
399
+ font-size: 13px !important;
400
+ }
401
+ """
402
+
403
+ # =============================================================================
404
+ # GRADIO APP
405
+ # =============================================================================
406
+ with gr.Blocks(css=CSS, title="PRETI Retinal AI") as demo:
407
+
408
+ gr.HTML("""
409
+ <div style="text-align:center;padding:24px 0 12px">
410
+ <div style="font-size:13px;color:#58A6FF;font-weight:600;
411
+ letter-spacing:2px;margin-bottom:6px">
412
+ UNDERGRADUATE THESIS Β· BIOMEDICAL ENGINEERING Β· 2025
413
+ </div>
414
+ <div style="font-size:28px;font-weight:700;color:#E6EDF3;
415
+ margin-bottom:8px">
416
+ πŸ”¬ PRETI Retinal Disease Detection
417
+ </div>
418
+ <div style="font-size:15px;color:#8B949E;margin-bottom:16px">
419
+ Multi-label detection of
420
+ <span style="color:#FF6B6B">DR</span> Β·
421
+ <span style="color:#51CF66">Glaucoma</span> Β·
422
+ <span style="color:#74C0FC">HR</span> Β·
423
+ <span style="color:#FFA94D">RVO</span>
424
+ with AGPT Bandwidth-Efficient Transmission
425
+ </div>
426
+ <div style="display:flex;justify-content:center;
427
+ gap:8px;flex-wrap:wrap;margin-bottom:8px">
428
+ <span style="background:#21262D;color:#51CF66;padding:4px 12px;
429
+ border-radius:20px;font-size:12px;font-weight:600;
430
+ border:1px solid #238636">βœ“ Macro AUC 0.9903</span>
431
+ <span style="background:#21262D;color:#74C0FC;padding:4px 12px;
432
+ border-radius:20px;font-size:12px;font-weight:600;
433
+ border:1px solid #1F6FEB">βœ“ 70.3% Bandwidth Saved</span>
434
+ <span style="background:#21262D;color:#FFA94D;padding:4px 12px;
435
+ border-radius:20px;font-size:12px;font-weight:600;
436
+ border:1px solid #9E6A03">βœ“ 4 Diseases Simultaneously</span>
437
+ <span style="background:#21262D;color:#FF6B6B;padding:4px 12px;
438
+ border-radius:20px;font-size:12px;font-weight:600;
439
+ border:1px solid #8B1A1A">βœ“ PRETI Foundation Model</span>
440
+ </div>
441
+ </div>
442
+ """)
443
+
444
+ with gr.Row():
445
+ with gr.Column(scale=1, min_width=280):
446
+ gr.HTML("""
447
+ <div style="background:#161B22;border:1px solid #30363D;
448
+ border-radius:10px;padding:16px;margin-bottom:8px">
449
+ <div style="color:#58A6FF;font-size:11px;font-weight:600;
450
+ letter-spacing:1px;margin-bottom:10px">
451
+ UPLOAD RETINAL IMAGE
452
+ </div>
453
+ <div style="color:#8B949E;font-size:12px;line-height:1.8">
454
+ β€’ Any fundus photograph<br>
455
+ β€’ JPEG or PNG format<br>
456
+ β€’ Any resolution supported<br>
457
+ β€’ Auto CLAHE preprocessing
458
+ </div>
459
+ </div>""")
460
+
461
+ inp = gr.Image(label="Retinal Fundus Image",
462
+ type="pil", height=260)
463
+ btn = gr.Button("πŸ” Analyze Retina",
464
+ variant="primary", size="lg")
465
+
466
+ gr.HTML("""
467
+ <div style="background:#161B22;border:1px solid #30363D;
468
+ border-radius:10px;padding:14px;margin-top:10px">
469
+ <div style="color:#58A6FF;font-size:11px;font-weight:600;
470
+ letter-spacing:1px;margin-bottom:8px">
471
+ PIPELINE
472
+ </div>
473
+ <div style="color:#8B949E;font-size:11px;line-height:1.8">
474
+ β‘  CLAHE contrast enhancement<br>
475
+ β‘‘ PRETI ViT-B/16 encoding<br>
476
+ β‘’ 4-disease classification<br>
477
+ β‘£ RAAM attention extraction<br>
478
+ β‘€ AGPT top-30% patch select<br>
479
+ β‘₯ 70.3% bandwidth reduction
480
+ </div>
481
+ </div>""")
482
+
483
+ with gr.Column(scale=3):
484
+ out_img = gr.Image(label="Analysis Result", height=520)
485
+
486
+ with gr.Row():
487
+ out_txt = gr.Textbox(label="Clinical Report",
488
+ lines=22, max_lines=28,
489
+ show_copy_button=True)
490
+
491
+ gr.HTML("""
492
+ <div style="text-align:center;padding:12px 0;
493
+ color:#484F58;font-size:11px">
494
+ PRETI: Lee et al., arXiv:2505.12233 (2025) Β·
495
+ Focal Loss: Lin et al., TPAMI 2020 Β·
496
+ Class-Balanced Sampler: Cui et al., CVPR 2019 Β·
497
+ AGPT: Novel Contribution
498
+ </div>""")
499
+
500
+ btn.click(fn=analyze,
501
+ inputs=[inp],
502
+ outputs=[out_img, out_txt])
503
+
504
+ if __name__ == "__main__":
505
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ timm>=0.9.0
5
+ opencv-python-headless>=4.8.0
6
+ Pillow>=9.0.0
7
+ numpy>=1.24.0
8
+ matplotlib>=3.7.0
9
+ huggingface_hub>=0.19.0